From 136fca52639ea1b04bdee620d0b5c822f42fe983 Mon Sep 17 00:00:00 2001 From: "Adam J. Jackson" Date: Wed, 10 Jul 2024 16:34:53 +0100 Subject: [PATCH 01/37] Create a Mixin for some Spectrum1DCollection methods, rewrite select() - This version of select() should be more robust in dealing with parameters that exist in "top level" of metadata dict - I hope it is also easier to understand --- euphonic/spectra.py | 180 +++++++++++++----- .../test_spectrum1dcollection.py | 4 +- 2 files changed, 134 insertions(+), 50 deletions(-) diff --git a/euphonic/spectra.py b/euphonic/spectra.py index 6d4a89d85..acf5c58e8 100644 --- a/euphonic/spectra.py +++ b/euphonic/spectra.py @@ -7,8 +7,9 @@ import math import json from numbers import Integral, Real -from typing import (Any, Callable, Dict, List, Literal, Optional, overload, +from typing import (Any, Callable, Dict, Generator, List, Literal, Optional, overload, Sequence, Tuple, TypeVar, Union, Type) +from typing_extensions import Self import warnings from pint import DimensionalityError, Quantity @@ -669,10 +670,91 @@ def broaden(self: T, x_width, return new_spectrum -LineData = Sequence[Dict[str, Union[str, int]]] +OneLineData = Dict[str, Union[str, int]] +LineData = Sequence[OneLineData] +Metadata = Dict[str, Union[str, int, LineData]] -class Spectrum1DCollection(collections.abc.Sequence, Spectrum): +class SpectrumCollectionMixin: + """Help a collection of spectra work with "line_data" metadata file + + This is a Mixin to be inherited by Spectrum collection classes + + To avoid redundancy, spectrum collections store metadata in the form + + {"key1": value1, "key2", value2, "line_data": [{"key3": value3, ...}, + {"key4": value4, ...}...]} + + - It is not guaranteed that all "lines" carry the same keys + - No key should appear at both top-level and in line-data; any key-value + pair at top level is assumed to apply to all lines + - "lines" can actually correspond to N-D spectra, the notation was devised + for multi-line plots of 
Spectrum1DCollection and then applied to other + purposes. + + """ + + def iter_metadata(self) -> Generator[OneLineData, None, None]: + """Iterate over metadata dicts of individual spectra from collection""" + common_metadata = dict((key, self.metadata[key]) for key in self.metadata.keys() - set("line_data")) + from itertools import repeat + + line_data = self.metadata.get("line_data") + if line_data is None: + line_data = repeat({}, len(self._z_data)) + + for one_line_data in line_data: + yield common_metadata | one_line_data + + def _select_indices(self, **select_key_values) -> list[int]: + required_metadata = select_key_values.items() + indices = [i for i, row in enumerate(self.iter_metadata()) if required_metadata <= row.items()] + return indices + + def select(self, **select_key_values: Union[ + str, int, Sequence[str], Sequence[int]]) -> Self: + """ + Select spectra by their keys and values in metadata['line_data'] + + Parameters + ---------- + **select_key_values + Key-value/values pairs in metadata['line_data'] describing + which spectra to extract. For example, to select all spectra + where metadata['line_data']['species'] = 'Na' or 'Cl' use + spectrum.select(species=['Na', 'Cl']). 
To select 'Na' and + 'Cl' spectra where weighting is also coherent, use + spectrum.select(species=['Na', 'Cl'], weighting='coherent') + + Returns + ------- + selected_spectra + A Spectrum1DCollection containing the selected spectra + + Raises + ------ + ValueError + If no matching spectra are found + """ + # Convert all items to sequences of possibilities + select_key_values = dict( + (key, (value,)) if isinstance(value, (int, str)) else (key, value) + for key, value in select_key_values.items() + ) + + # Collect indices that match each combination of values + selected_indices = [] + for value_combination in itertools.product(*select_key_values.values()): + selection = dict(zip(select_key_values.keys(), value_combination)) + selected_indices.extend(self._select_indices(**selection)) + + if not selected_indices: + raise ValueError(f'No spectra found with matching metadata ' + f'for {select_key_values}') + + return self[selected_indices] + +class Spectrum1DCollection(collections.abc.Sequence, SpectrumCollectionMixin, Spectrum): """A collection of Spectrum1D with common x_data and x_tick_labels Intended for convenient storage of band structures, projected DOS @@ -1201,52 +1283,52 @@ def sum(self) -> Spectrum1D: x_tick_labels=copy.copy(self.x_tick_labels), metadata=copy.deepcopy(metadata)) - def select(self, **select_key_values: Union[ - str, int, Sequence[str], Sequence[int]]) -> T: - """ - Select spectra by their keys and values in metadata['line_data'] - - Parameters - ---------- - **select_key_values - Key-value/values pairs in metadata['line_data'] describing - which spectra to extract. For example, to select all spectra - where metadata['line_data']['species'] = 'Na' or 'Cl' use - spectrum.select(species=['Na', 'Cl']). 
To select 'Na' and - 'Cl' spectra where weighting is also coherent, use - spectrum.select(species=['Na', 'Cl'], weighting='coherent') - - Returns - ------- - selected_spectra - A Spectrum1DCollection containing the selected spectra - - Raises - ------ - ValueError - If no matching spectra are found - """ - select_val_dict = _get_unique_elems_and_idx( - self._get_line_data_vals(*select_key_values.keys())) - for key, value in select_key_values.items(): - if isinstance(value, (int, str)): - select_key_values[key] = [value] - value_combinations = itertools.product(*select_key_values.values()) - select_idx = np.array([], dtype=np.int32) - for value_combo in value_combinations: - try: - idx = select_val_dict[value_combo] - # Don't require every combination to match e.g. - # spec.select(sample=[0, 2], inst=['MAPS', 'MARI']) - # we don't want to error simply because there are no - # inst='MAPS' and sample=2 combinations - except KeyError: - continue - select_idx = np.append(select_idx, idx) - if len(select_idx) == 0: - raise ValueError(f'No spectra found with matching metadata ' - f'for {select_key_values}') - return self[select_idx] + # def select(self, **select_key_values: Union[ + # str, int, Sequence[str], Sequence[int]]) -> T: + # """ + # Select spectra by their keys and values in metadata['line_data'] + + # Parameters + # ---------- + # **select_key_values + # Key-value/values pairs in metadata['line_data'] describing + # which spectra to extract. For example, to select all spectra + # where metadata['line_data']['species'] = 'Na' or 'Cl' use + # spectrum.select(species=['Na', 'Cl']). 
To select 'Na' and + # 'Cl' spectra where weighting is also coherent, use + # spectrum.select(species=['Na', 'Cl'], weighting='coherent') + + # Returns + # ------- + # selected_spectra + # A Spectrum1DCollection containing the selected spectra + + # Raises + # ------ + # ValueError + # If no matching spectra are found + # """ + # select_val_dict = _get_unique_elems_and_idx( + # self._get_line_data_vals(*select_key_values.keys())) + # for key, value in select_key_values.items(): + # if isinstance(value, (int, str)): + # select_key_values[key] = [value] + # value_combinations = itertools.product(*select_key_values.values()) + # select_idx = np.array([], dtype=np.int32) + # for value_combo in value_combinations: + # try: + # idx = select_val_dict[value_combo] + # # Don't require every combination to match e.g. + # # spec.select(sample=[0, 2], inst=['MAPS', 'MARI']) + # # we don't want to error simply because there are no + # # inst='MAPS' and sample=2 combinations + # except KeyError: + # continue + # select_idx = np.append(select_idx, idx) + # if len(select_idx) == 0: + # raise ValueError(f'No spectra found with matching metadata ' + # f'for {select_key_values}') + # return self[select_idx] class Spectrum2D(Spectrum): diff --git a/tests_and_analysis/test/euphonic_test/test_spectrum1dcollection.py b/tests_and_analysis/test/euphonic_test/test_spectrum1dcollection.py index 24aedbd5e..cca15fe5d 100644 --- a/tests_and_analysis/test/euphonic_test/test_spectrum1dcollection.py +++ b/tests_and_analysis/test/euphonic_test/test_spectrum1dcollection.py @@ -623,7 +623,9 @@ def test_select(self, spectrum_file, select_kwargs, [3, 5]), ('La2Zr2O7_666_coh_incoh_species_append_pdos.json', {'weighting': 'incoherent', 'species': 'O'}, - [3]) + [3]), + ('methane_pdos.json', + {'desc': 'Methane PDOS', 'label': 'H3'}, [2]), ]) def test_select_same_as_indexing(self, spectrum_file, select_kwargs, expected_indices): From e825e6907b902a62af4eb12b85eff3be8337c6ac Mon Sep 17 00:00:00 2001 From: 
"Adam J. Jackson" Date: Fri, 12 Jul 2024 13:00:03 +0100 Subject: [PATCH 02/37] Tidying up: rename and move private metadata-handling methods --- euphonic/spectra.py | 149 +++++++++++++++++--------------------------- 1 file changed, 56 insertions(+), 93 deletions(-) diff --git a/euphonic/spectra.py b/euphonic/spectra.py index acf5c58e8..07aeb75df 100644 --- a/euphonic/spectra.py +++ b/euphonic/spectra.py @@ -754,6 +754,52 @@ def select(self, **select_key_values: Union[ return self[selected_indices] + @staticmethod + def _combine_metadata(all_metadata: LineData) -> Metadata: + """ + From a sequence of metadata dictionaries, combines all common + key/value pairs into the top level of a metadata dictionary, + all unmatching key/value pairs are put into the 'line_data' + key, which is a list of metadata dicts for each element in + all_metadata + """ + # This is for combining multiple separate spectrum metadata, + # they shouldn't have line_data + for metadata in all_metadata: + assert 'line_data' not in metadata.keys() + + # Combine all common key/value pairs into new dict + combined_metadata = dict( + set(all_metadata[0].items()).intersection( + *[metadata.items() for metadata in all_metadata[1:]])) + + # Put all other per-spectrum metadata in line_data + line_data = [ + {key: value for key, value in metadata.items() + if key not in combined_metadata} + for metadata in all_metadata + ] + if any(line_data): + combined_metadata['line_data'] = line_data + + return combined_metadata + + def _tidy_metadata(self, indices: Optional[Sequence[int]] = None + ) -> Metadata: + """ + For a metadata dictionary, combines all common key/value + pairs in 'line_data' and puts them in a top-level dictionary. + If indices is supplied, only those indices in 'line_data' are + combined. 
Unmatching key/value pairs are discarded + """ + line_data = self.metadata.get("line_data", [{}] * len(self)) + if indices is not None: + line_data = [line_data[idx] for idx in indices] + combined_line_data = self._combine_metadata(line_data) + combined_line_data.pop("line_data", None) + return combined_line_data + + class Spectrum1DCollection(collections.abc.Sequence, SpectrumCollectionMixin, Spectrum): """A collection of Spectrum1D with common x_data and x_tick_labels @@ -944,50 +990,6 @@ def _type_check(spectrum): return cls(x_data, y_data, x_tick_labels=x_tick_labels, metadata=metadata) - @staticmethod - def _combine_metadata(all_metadata: Sequence[Dict[str, Union[int, str]]] - ) -> Dict[str, Union[int, str, LineData]]: - """ - From a sequence of metadata dictionaries, combines all common - key/value pairs into the top level of a metadata dictionary, - all unmatching key/value pairs are put into the 'line_data' - key, which is a list of metadata dicts for each element in - all_metadata - """ - # This is for combining multiple separate spectrum metadata, - # they shouldn't have line_data - for metadata in all_metadata: - assert 'line_data' not in metadata.keys() - # Combine all common key/value pairs - combined_metadata = dict( - set(all_metadata[0].items()).intersection( - *[metadata.items() for metadata in all_metadata[1:]])) - # Put all other per-spectrum metadata in line_data - line_data = [] - for i, metadata in enumerate(all_metadata): - sdata = copy.deepcopy(metadata) - for key in combined_metadata.keys(): - sdata.pop(key) - line_data.append(sdata) - if any(line_data): - combined_metadata['line_data'] = line_data - return combined_metadata - - def _combine_line_metadata(self, indices: Optional[Sequence[int]] = None - ) -> Dict[str, Any]: - """ - For a metadata dictionary, combines all common key/value - pairs in 'line_data' and puts them in a top-level dictionary. - If indices is supplied, only those indices in 'line_data' are - combined. 
Unmatching key/value pairs are discarded - """ - line_data = self.metadata.get('line_data', [{}]*len(self)) - if indices is not None: - line_data = [line_data[idx] for idx in indices] - combined_line_data = self._combine_metadata(line_data) - combined_line_data.pop('line_data', None) - return combined_line_data - def _get_line_data_vals(self, *line_data_keys: str) -> np.ndarray: """ Get value of the key(s) for each element in @@ -1242,6 +1244,14 @@ def group_by(self, *line_data_keys: str) -> T: metadata in 'line_data' not common across all spectra in a group will be discarded """ + # Remove line_data_keys that are not found in top level of metadata: + # these will not be useful for grouping + keys = [key for key in line_data_keys if key not in self.metadata] + + # If there are no keys left, sum everything as one big group and return + if not keys: + return self.from_spectra([self.sum()]) + grouping_dict = _get_unique_elems_and_idx( self._get_line_data_vals(*line_data_keys)) @@ -1250,7 +1260,7 @@ def group_by(self, *line_data_keys: str) -> T: group_metadata['line_data'] = [{}]*len(grouping_dict) for i, idxs in enumerate(grouping_dict.values()): # Look for any common key/values in grouped metadata - group_i_metadata = self._combine_line_metadata(idxs) + group_i_metadata = self._tidy_metadata(idxs) group_metadata['line_data'][i] = group_i_metadata new_y_data[i] = np.sum(self._y_data[idxs], axis=0) new_y_data = new_y_data*ureg(self._internal_y_data_unit).to( @@ -1275,7 +1285,7 @@ def sum(self) -> Spectrum1D: """ metadata = copy.deepcopy(self.metadata) metadata.pop('line_data', None) - metadata.update(self._combine_line_metadata()) + metadata.update(self._tidy_metadata()) summed_y_data = np.sum(self._y_data, axis=0)*ureg( self._internal_y_data_unit).to(self.y_data_unit) return Spectrum1D(np.copy(self.x_data), @@ -1283,53 +1293,6 @@ def sum(self) -> Spectrum1D: x_tick_labels=copy.copy(self.x_tick_labels), metadata=copy.deepcopy(metadata)) - # def select(self, 
**select_key_values: Union[ - # str, int, Sequence[str], Sequence[int]]) -> T: - # """ - # Select spectra by their keys and values in metadata['line_data'] - - # Parameters - # ---------- - # **select_key_values - # Key-value/values pairs in metadata['line_data'] describing - # which spectra to extract. For example, to select all spectra - # where metadata['line_data']['species'] = 'Na' or 'Cl' use - # spectrum.select(species=['Na', 'Cl']). To select 'Na' and - # 'Cl' spectra where weighting is also coherent, use - # spectrum.select(species=['Na', 'Cl'], weighting='coherent') - - # Returns - # ------- - # selected_spectra - # A Spectrum1DCollection containing the selected spectra - - # Raises - # ------ - # ValueError - # If no matching spectra are found - # """ - # select_val_dict = _get_unique_elems_and_idx( - # self._get_line_data_vals(*select_key_values.keys())) - # for key, value in select_key_values.items(): - # if isinstance(value, (int, str)): - # select_key_values[key] = [value] - # value_combinations = itertools.product(*select_key_values.values()) - # select_idx = np.array([], dtype=np.int32) - # for value_combo in value_combinations: - # try: - # idx = select_val_dict[value_combo] - # # Don't require every combination to match e.g. - # # spec.select(sample=[0, 2], inst=['MAPS', 'MARI']) - # # we don't want to error simply because there are no - # # inst='MAPS' and sample=2 combinations - # except KeyError: - # continue - # select_idx = np.append(select_idx, idx) - # if len(select_idx) == 0: - # raise ValueError(f'No spectra found with matching metadata ' - # f'for {select_key_values}') - # return self[select_idx] - class Spectrum2D(Spectrum): """ From 067339ab55f02b849936e6c6292695b0aeba4c3f Mon Sep 17 00:00:00 2001 From: "Adam J. 
Jackson" Date: Fri, 12 Jul 2024 15:00:10 +0100 Subject: [PATCH 03/37] More refactoring in preparation for Spectrum2DCollection --- euphonic/spectra.py | 325 ++++++++++++++++++++++++++------------------ 1 file changed, 195 insertions(+), 130 deletions(-) diff --git a/euphonic/spectra.py b/euphonic/spectra.py index 07aeb75df..4921fb734 100644 --- a/euphonic/spectra.py +++ b/euphonic/spectra.py @@ -675,7 +675,7 @@ def broaden(self: T, x_width, Metadata = Dict[str, Union[str, int, LineData]] -class SpectrumCollectionMixin: +class SpectrumCollectionMixin(ABC): """Help a collection of spectra work with "line_data" metadata file This is a Mixin to be inherited by Spectrum collection classes @@ -692,23 +692,124 @@ class SpectrumCollectionMixin: for multi-line plots of Spectrum1DCollection and then applied to other purposes. + The _spectrum_axis class attribute determines which axis property contains + the spectral data, and should be set by subclasses (i.e. to "y" or "z" for + 1D or 2D). """ + # Subclasses must define which axis contains the spectral data for + # purposes of splitting, indexing, etc. + # Python doesn't support abstract class attributes so we define a default + # value, ensuring _something_ was set. 
+ _bin_axes = ("x",) + _spectrum_axis = "y" + _item_type = Spectrum1D + + # Define some private methods which wrap this information into useful forms + def _spectrum_data_name(self) -> str: + return f"{self._spectrum_axis}_data" + + def _spectrum_raw_data_name(self) -> str: + return f"_{self._spectrum_axis}_data" + + def _get_spectrum_data(self) -> Quantity: + return getattr(self, self._spectrum_data_name()) + + def _get_raw_spectrum_data(self) -> np.ndarray: + return getattr(self, self._spectrum_raw_data_name()) + + def _set_spectrum_data(self, data: Quantity) -> None: + setattr(self, self._spectrum_data_name(), data) + + def _set_raw_spectrum_data(self, data: np.ndarray) -> None: + setattr(self, self._spectrum_raw_data_name(), data) + + def _get_spectrum_data_unit(self) -> str: + return getattr(self, f"{self._spectrum_data_name()}_unit") + + def _get_internal_spectrum_data_unit(self) -> str: + return getattr(self, f"_internal_{self._spectrum_data_name()}_unit") + + def _get_bin_kwargs(self) -> Dict[str, Quantity]: + """Get constructor args for bin axes from current data + + e.g. for Spectrum2DCollection this is + + {"x_data": self.x_data, "y_data": self.y_data} + """ + return {f"{axis}_data": getattr(self, f"{axis}_data") + for axis in self._bin_axes} + + def sum(self) -> Spectrum: + """ + Sum collection to a single spectrum + + Returns + ------- + summed_spectrum + A single combined spectrum from all items in collection. 
Any + metadata in 'line_data' not common across all spectra will be + discarded + """ + metadata = copy.deepcopy(self.metadata) + metadata.pop('line_data', None) + metadata.update(self._tidy_metadata()) + summed_s_data = np.sum(self._get_raw_spectrum_data(), axis=0 + ) * ureg(self._get_internal_spectrum_data_unit() + ).to(self._get_spectrum_data_unit()) + return Spectrum1D( + **self._get_bin_kwargs(), + **{self._spectrum_data_name(): summed_s_data}, + x_tick_labels=copy.copy(self.x_tick_labels), + metadata=metadata + ) + + + # Required methods + @classmethod + @abstractmethod + def from_spectra(cls, spectra: Sequence[Spectrum]) -> Self: ... + + # Mixin methods + def __len__(self): + return self._get_raw_spectrum_data().shape[0] + + def copy(self) -> Self: + """Get an independent copy of spectrum""" + return self._item_type.copy(self) + + def __add__(self, other: Self) -> Self: + """ + Appends the y_data of 2 Spectrum1DCollection objects, + creating a single Spectrum1DCollection that contains + the spectra from both objects. 
The two objects must + have equal x_data axes, and their y_data must + have compatible units and the same number of y_data + entries + + Any metadata key/value pairs that are common to both + spectra are retained in the top level dictionary, any + others are put in the individual 'line_data' entries + """ + return type(self).from_spectra([*self, *other]) + def iter_metadata(self) -> Generator[OneLineData, None, None]: """Iterate over metadata dicts of individual spectra from collection""" - common_metadata = dict((key, self.metadata[key]) for key in self.metadata.keys() - set("line_data")) - from itertools import repeat + common_metadata = dict( + (key, self.metadata[key]) + for key in self.metadata.keys() - set("line_data")) line_data = self.metadata.get("line_data") if line_data is None: - line_data = repeat({}, len(self._z_data)) + line_data = itertools.repeat({}, len(self._z_data)) for one_line_data in line_data: yield common_metadata | one_line_data def _select_indices(self, **select_key_values) -> list[int]: required_metadata = select_key_values.items() - indices = [i for i, row in enumerate(self.iter_metadata()) if required_metadata <= row.items()] + indices = [i for i, row in enumerate(self.iter_metadata()) + if required_metadata <= row.items()] return indices def select(self, **select_key_values: Union[ @@ -799,8 +900,89 @@ def _tidy_metadata(self, indices: Optional[Sequence[int]] = None combined_line_data.pop("line_data", None) return combined_line_data + def _get_line_data_vals(self, *line_data_keys: str) -> np.ndarray: + """ + Get value of the key(s) for each element in + metadata['line_data']. Returns a 1D array of tuples, where each + tuple contains the value(s) for each key in line_data_keys, for + a single element in metadata['line_data']. 
This allows easy + grouping/selecting by specific keys + + For example, if we have a Spectrum1DCollection with the following + metadata: + {'desc': 'Quartz', 'line_data': [ + {'inst': 'LET', 'sample': 0, 'index': 1}, + {'inst': 'MAPS', 'sample': 1, 'index': 2}, + {'inst': 'MARI', 'sample': 1, 'index': 1}, + ]} + Then: + _get_line_data_vals('inst', 'sample') = [('LET', 0), + ('MAPS', 1), + ('MARI', 1)] + + Raises a KeyError if 'line_data' or the key doesn't exist + """ + line_data = self.metadata['line_data'] + line_data_vals = np.empty(len(line_data), dtype=object) + for i, data in enumerate(line_data): + line_data_vals[i] = tuple([data[key] for key in line_data_keys]) + return line_data_vals + + def group_by(self, *line_data_keys: str) -> Self: + """ + Group and sum elements of spectral data according to the values + mapped to the specified keys in metadata['line_data'] + + Parameters + ---------- + line_data_keys + The key(s) to group by. If only one line_data_key is + supplied, if the value mapped to a key is the same for + multiple spectra, they are placed in the same group and + summed. If multiple line_data_keys are supplied, the values + must be the same for all specified keys for them to be + placed in the same group + + Returns + ------- + grouped_spectrum + A new Spectrum1DCollection with one line for each group. 
Any + metadata in 'line_data' not common across all spectra in a + group will be discarded + """ + # Remove line_data_keys that are not found in top level of metadata: + # these will not be useful for grouping + keys = [key for key in line_data_keys if key not in self.metadata] + + # If there are no keys left, sum everything as one big group and return + if not keys: + return self.from_spectra([self.sum()]) + + grouping_dict = _get_unique_elems_and_idx( + self._get_line_data_vals(*line_data_keys)) + + new_s_data = np.zeros((len(grouping_dict), + *self._get_raw_spectrum_data().shape[1:])) + group_metadata = copy.deepcopy(self.metadata) + group_metadata['line_data'] = [{}]*len(grouping_dict) + for i, idxs in enumerate(grouping_dict.values()): + # Look for any common key/values in grouped metadata + group_i_metadata = self._tidy_metadata(idxs) + group_metadata['line_data'][i] = group_i_metadata + new_s_data[i] = np.sum(self._get_raw_spectrum_data()[idxs], axis=0) + new_s_data = new_s_data*ureg(self._get_internal_spectrum_data_unit()).to( + self._get_spectrum_data_unit()) + + new_data = self.copy() + new_data._set_spectrum_data(new_s_data) + new_data.metadata = group_metadata + + return new_data + -class Spectrum1DCollection(collections.abc.Sequence, SpectrumCollectionMixin, Spectrum): +class Spectrum1DCollection(SpectrumCollectionMixin, + Spectrum, + collections.abc.Sequence): """A collection of Spectrum1D with common x_data and x_tick_labels Intended for convenient storage of band structures, projected DOS @@ -834,6 +1016,10 @@ class Spectrum1DCollection(collections.abc.Sequence, SpectrumCollectionMixin, Sp """ T = TypeVar('T', bound='Spectrum1DCollection') + # Private attributes used by SpectrumCollectionMixin + _spectrum_axis = "y" + _item_type = Spectrum1D + def __init__( self, x_data: Quantity, y_data: Quantity, x_tick_labels: Optional[Sequence[Tuple[int, str]]] = None, @@ -885,24 +1071,9 @@ def __init__( f'{len(metadata["line_data"])} entries') self.metadata = 
{} if metadata is None else metadata - def __add__(self: T, other: T) -> T: - """ - Appends the y_data of 2 Spectrum1DCollection objects, - creating a single Spectrum1DCollection that contains - the spectra from both objects. The two objects must - have equal x_data axes, and their y_data must - have compatible units and the same number of y_data - entries - - Any metadata key/value pairs that are common to both - spectra are retained in the top level dictionary, any - others are put in the individual 'line_data' entries - """ - return type(self).from_spectra([*self, *other]) - def _split_by_indices(self, indices: Union[Sequence[int], np.ndarray] - ) -> List[T]: + ) -> List[Self]: """Split data along x-axis at given indices""" ranges = self._ranges_from_indices(indices) @@ -913,19 +1084,16 @@ def _split_by_indices(self, metadata=self.metadata) for x0, x1 in ranges] - def __len__(self): - return self.y_data.shape[0] - @overload def __getitem__(self, item: int) -> Spectrum1D: ... @overload # noqa: F811 - def __getitem__(self, item: slice) -> T: + def __getitem__(self, item: slice) -> Self: ... @overload # noqa: F811 - def __getitem__(self, item: Union[Sequence[int], np.ndarray]) -> T: + def __getitem__(self, item: Union[Sequence[int], np.ndarray]) -> Self: ... def __getitem__(self, item: Union[int, slice, Sequence[int], np.ndarray] @@ -990,38 +1158,6 @@ def _type_check(spectrum): return cls(x_data, y_data, x_tick_labels=x_tick_labels, metadata=metadata) - def _get_line_data_vals(self, *line_data_keys: str) -> np.ndarray: - """ - Get value of the key(s) for each element in - metadata['line_data']. Returns a 1D array of tuples, where each - tuple contains the value(s) for each key in line_data_keys, for - a single element in metadata['line_data']. 
This allows easy - grouping/selecting by specific keys - - For example, if we have a Spectrum1DCollection with the following - metadata: - {'desc': 'Quartz', 'line_data': [ - {'inst': 'LET', 'sample': 0, 'index': 1}, - {'inst': 'MAPS', 'sample': 1, 'index': 2}, - {'inst': 'MARI', 'sample': 1, 'index': 1}, - ]} - Then: - _get_line_data_vals('inst', 'sample') = [('LET', 0), - ('MAPS', 1), - ('MARI', 1)] - - Raises a KeyError if 'line_data' or the key doesn't exist - """ - line_data = self.metadata['line_data'] - line_data_vals = np.empty(len(line_data), dtype=object) - for i, data in enumerate(line_data): - line_data_vals[i] = tuple([data[key] for key in line_data_keys]) - return line_data_vals - - def copy(self: T) -> T: - """Get an independent copy of spectrum""" - return Spectrum1D.copy(self) - def to_dict(self) -> Dict[str, Any]: """ Convert to a dictionary consistent with from_dict() @@ -1222,77 +1358,6 @@ def broaden(self: T, else: raise TypeError("x_width must be a Quantity or Callable") - def group_by(self, *line_data_keys: str) -> T: - """ - Group and sum y_data for each spectrum according to the values - mapped to the specified keys in metadata['line_data'] - - Parameters - ---------- - line_data_keys - The key(s) to group by. If only one line_data_key is - supplied, if the value mapped to a key is the same for - multiple spectra, they are placed in the same group and - summed. If multiple line_data_keys are supplied, the values - must be the same for all specified keys for them to be - placed in the same group - - Returns - ------- - grouped_spectrum - A new Spectrum1DCollection with one line for each group. 
Any - metadata in 'line_data' not common across all spectra in a - group will be discarded - """ - # Remove line_data_keys that are not found in top level of metadata: - # these will not be useful for grouping - keys = [key for key in line_data_keys if key not in self.metadata] - - # If there are no keys left, sum everything as one big group and return - if not keys: - return self.from_spectra([self.sum()]) - - grouping_dict = _get_unique_elems_and_idx( - self._get_line_data_vals(*line_data_keys)) - - new_y_data = np.zeros((len(grouping_dict), self._y_data.shape[-1])) - group_metadata = copy.deepcopy(self.metadata) - group_metadata['line_data'] = [{}]*len(grouping_dict) - for i, idxs in enumerate(grouping_dict.values()): - # Look for any common key/values in grouped metadata - group_i_metadata = self._tidy_metadata(idxs) - group_metadata['line_data'][i] = group_i_metadata - new_y_data[i] = np.sum(self._y_data[idxs], axis=0) - new_y_data = new_y_data*ureg(self._internal_y_data_unit).to( - self.y_data_unit) - - new_data = self.copy() - new_data.y_data = new_y_data - new_data.metadata = group_metadata - - return new_data - - def sum(self) -> Spectrum1D: - """ - Sum y_data over all spectra - - Returns - ------- - summed_spectrum - A Spectrum1D created from the summed y_data. Any metadata - in 'line_data' not common across all spectra will be - discarded - """ - metadata = copy.deepcopy(self.metadata) - metadata.pop('line_data', None) - metadata.update(self._tidy_metadata()) - summed_y_data = np.sum(self._y_data, axis=0)*ureg( - self._internal_y_data_unit).to(self.y_data_unit) - return Spectrum1D(np.copy(self.x_data), - summed_y_data, - x_tick_labels=copy.copy(self.x_tick_labels), - metadata=copy.deepcopy(metadata)) - class Spectrum2D(Spectrum): """ From 28a5b1e2ea4841754a3325ee4cde9cc770fb48f9 Mon Sep 17 00:00:00 2001 From: "Adam J. Jackson" Date: Fri, 12 Jul 2024 15:44:44 +0100 Subject: [PATCH 04/37] Update Spectrum methods to use Quantity rather than *ureg. 
Some of this was done in another branch, and this one rebased --- euphonic/spectra.py | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/euphonic/spectra.py b/euphonic/spectra.py index 4921fb734..9bdebc640 100644 --- a/euphonic/spectra.py +++ b/euphonic/spectra.py @@ -565,8 +565,9 @@ def from_castep_phonon_dos(cls: Type[T], filename: str, metadata['species'] = element metadata['label'] = element - return cls(data['dos_bins']*ureg(data['dos_bins_unit']), - data['dos'][element]*ureg(data['dos_unit']), + return cls(ureg.Quantity(data["dos_bins"], + units=data["dos_bins_unit"]), + ureg.Quantity(data["dos"][element], units=data["dos_unit"]), metadata=metadata) @overload @@ -645,7 +646,8 @@ def broaden(self: T, x_width, self.y_data.magnitude, [self.get_bin_centres().magnitude], [x_width.to(self.x_data_unit).magnitude], - shape=shape, method=method) * ureg(self.y_data_unit) + shape=shape, method=method) + y_broadened = ureg.Quantity(y_broadened, units=self.y_data_unit) elif isinstance(x_width, Callable): self.assert_regular_bins(message=( @@ -754,9 +756,10 @@ def sum(self) -> Spectrum: metadata = copy.deepcopy(self.metadata) metadata.pop('line_data', None) metadata.update(self._tidy_metadata()) - summed_s_data = np.sum(self._get_raw_spectrum_data(), axis=0 - ) * ureg(self._get_internal_spectrum_data_unit() - ).to(self._get_spectrum_data_unit()) + summed_s_data = ureg.Quantity( + np.sum(self._get_raw_spectrum_data(), axis=0), + units=self._get_internal_spectrum_data_unit() + ).to(self._get_spectrum_data_unit()) return Spectrum1D( **self._get_bin_kwargs(), **{self._spectrum_data_name(): summed_s_data}, @@ -970,8 +973,10 @@ def group_by(self, *line_data_keys: str) -> Self: group_i_metadata = self._tidy_metadata(idxs) group_metadata['line_data'][i] = group_i_metadata new_s_data[i] = np.sum(self._get_raw_spectrum_data()[idxs], axis=0) - new_s_data = new_s_data*ureg(self._get_internal_spectrum_data_unit()).to( - 
self._get_spectrum_data_unit()) + + new_s_data = ureg.Quantity(new_s_data, + units=self._get_internal_spectrum_data_unit() + ).to(self._get_spectrum_data_unit()) new_data = self.copy() new_data._set_spectrum_data(new_s_data) @@ -1253,8 +1258,8 @@ def from_castep_phonon_dos(cls: Type[T], filename: str) -> T: metadata['line_data'][i]['species'] = species metadata['line_data'][i]['label'] = species return Spectrum1DCollection( - data['dos_bins']*ureg(data['dos_bins_unit']), - y_data*ureg(data['dos_unit']), + ureg.Quantity(data['dos_bins'], units=data['dos_bins_unit']), + ureg.Quantity(y_data, units=data['dos_unit']), metadata=metadata) @overload @@ -1340,7 +1345,7 @@ def broaden(self: T, method=method) new_spectrum = self.copy() - new_spectrum.y_data = y_broadened * ureg(self.y_data_unit) + new_spectrum.y_data = ureg.Quantity(y_broadened, units=self.y_data_unit) return new_spectrum elif isinstance(x_width, Callable): @@ -1552,7 +1557,7 @@ def broaden(self: T, method=method) spectrum = Spectrum2D(np.copy(self.x_data), np.copy(self.y_data), - z_broadened*ureg(self.z_data_unit), + ureg.Quantity(z_broadened, units=self.z_data_unit), copy.copy(self.x_tick_labels), copy.deepcopy(self.metadata)) else: From a47778f18c0ae7d5fa6c8a80e60c06ce4359c075 Mon Sep 17 00:00:00 2001 From: "Adam J. Jackson" Date: Wed, 7 Aug 2024 16:39:01 +0100 Subject: [PATCH 05/37] Pylinting Pylint doesn't like the call to ._set_spectrum_data on another class instance because it doesn't understand that this instance was just created with the current class. I think that means the warning is ok to suppress. --- euphonic/spectra.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/euphonic/spectra.py b/euphonic/spectra.py index 9bdebc640..91803225d 100644 --- a/euphonic/spectra.py +++ b/euphonic/spectra.py @@ -771,7 +771,9 @@ def sum(self) -> Spectrum: # Required methods @classmethod @abstractmethod - def from_spectra(cls, spectra: Sequence[Spectrum]) -> Self: ... 
+ def from_spectra(cls, spectra: Sequence[Spectrum]) -> Self: + """Construct spectrum collection from a sequence of components""" + ... # Mixin methods def __len__(self): @@ -979,7 +981,7 @@ def group_by(self, *line_data_keys: str) -> Self: ).to(self._get_spectrum_data_unit()) new_data = self.copy() - new_data._set_spectrum_data(new_s_data) + new_data._set_spectrum_data(new_s_data) # pylint: disable=W0212 new_data.metadata = group_metadata return new_data From f98c0537558d221e3ae565fe636d9fd200e4c903 Mon Sep 17 00:00:00 2001 From: "Adam J. Jackson" Date: Thu, 8 Aug 2024 09:54:57 +0100 Subject: [PATCH 06/37] Reimplement spectrum group_by with toolz, FP style This adds toolz as a dependency to the main package. We don't anticipate that causing a lot of problems; it is a small, stable, pure-python library also available on conda-forge. --- euphonic/spectra.py | 35 +++++-------------- setup.py | 3 +- .../minimum_euphonic_requirements.txt | 1 + 3 files changed, 11 insertions(+), 28 deletions(-) diff --git a/euphonic/spectra.py b/euphonic/spectra.py index 91803225d..ce191d880 100644 --- a/euphonic/spectra.py +++ b/euphonic/spectra.py @@ -7,6 +7,7 @@ import math import json from numbers import Integral, Real +from toolz.itertoolz import groupby, pluck from typing import (Any, Callable, Dict, Generator, List, Literal, Optional, overload, Sequence, Tuple, TypeVar, Union, Type) from typing_extensions import Self @@ -955,36 +956,16 @@ def group_by(self, *line_data_keys: str) -> Self: metadata in 'line_data' not common across all spectra in a group will be discarded """ - # Remove line_data_keys that are not found in top level of metadata: - # these will not be useful for grouping - keys = [key for key in line_data_keys if key not in self.metadata] + get_key_items = lambda enumerated_metadata: tuple( + enumerated_metadata[1].get(item, None) for item in line_data_keys) - # If there are no keys left, sum everything as one big group and return - if not keys: - return 
self.from_spectra([self.sum()]) + groups = groupby(get_key_items, enumerate(self.iter_metadata())) - grouping_dict = _get_unique_elems_and_idx( - self._get_line_data_vals(*line_data_keys)) + indices = lambda enumerated_values: pluck(0, enumerated_values) + sum_over_indices = lambda indices: self[list(indices)].sum() - new_s_data = np.zeros((len(grouping_dict), - *self._get_raw_spectrum_data().shape[1:])) - group_metadata = copy.deepcopy(self.metadata) - group_metadata['line_data'] = [{}]*len(grouping_dict) - for i, idxs in enumerate(grouping_dict.values()): - # Look for any common key/values in grouped metadata - group_i_metadata = self._tidy_metadata(idxs) - group_metadata['line_data'][i] = group_i_metadata - new_s_data[i] = np.sum(self._get_raw_spectrum_data()[idxs], axis=0) - - new_s_data = ureg.Quantity(new_s_data, - units=self._get_internal_spectrum_data_unit() - ).to(self._get_spectrum_data_unit()) - - new_data = self.copy() - new_data._set_spectrum_data(new_s_data) # pylint: disable=W0212 - new_data.metadata = group_metadata - - return new_data + return self.from_spectra([sum_over_indices(indices(group)) + for group in groups.values()]) class Spectrum1DCollection(SpectrumCollectionMixin, diff --git a/setup.py b/setup.py index c7b9250cb..02f28dae1 100644 --- a/setup.py +++ b/setup.py @@ -144,7 +144,8 @@ def run_setup(): 'seekpath>=1.1.0', 'spglib>=1.9.4', 'pint>=0.22', - 'threadpoolctl>=3.0.0' + 'threadpoolctl>=3.0.0', + 'toolz>=0.12.1', ], extras_require={ 'matplotlib': ['matplotlib>=3.8.0'], diff --git a/tests_and_analysis/minimum_euphonic_requirements.txt b/tests_and_analysis/minimum_euphonic_requirements.txt index 065a89b61..dbddec3df 100644 --- a/tests_and_analysis/minimum_euphonic_requirements.txt +++ b/tests_and_analysis/minimum_euphonic_requirements.txt @@ -7,3 +7,4 @@ matplotlib==3.8 h5py==3.6 PyYAML==6.0 threadpoolctl==3.0.0 +toolz==0.12.1 From 5b29747d67d441e9b2b3896e5c00e01648726250 Mon Sep 17 00:00:00 2001 From: "Adam J. 
Jackson" Date: Thu, 8 Aug 2024 10:34:47 +0100 Subject: [PATCH 07/37] Tidy up a bit Linters hate the named lambdas. I think they are quite nice because they are compact and immediately draw attention to the "one-liner" they attach a name to... but maybe the more explicit form with type hints will make it easier for someone to understand in future. --- euphonic/spectra.py | 32 ++++++++++++++++++++++---------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/euphonic/spectra.py b/euphonic/spectra.py index ce191d880..c77b54807 100644 --- a/euphonic/spectra.py +++ b/euphonic/spectra.py @@ -8,8 +8,9 @@ import json from numbers import Integral, Real from toolz.itertoolz import groupby, pluck -from typing import (Any, Callable, Dict, Generator, List, Literal, Optional, overload, - Sequence, Tuple, TypeVar, Union, Type) +from typing import (Any, Callable, Dict, Generator, Iterable, Iterator, List, + Literal, Optional, overload, Sequence, Tuple, TypeVar, + Union, Type) from typing_extensions import Self import warnings @@ -23,7 +24,6 @@ from euphonic.io import (_obj_to_json_file, _obj_from_json_file, _obj_to_dict, _process_dict) from euphonic.readers.castep import read_phonon_dos_data -from euphonic.util import _get_unique_elems_and_idx from euphonic.validate import _check_constructor_inputs, _check_unit_conversion @@ -768,7 +768,6 @@ def sum(self) -> Spectrum: metadata=metadata ) - # Required methods @classmethod @abstractmethod @@ -851,7 +850,8 @@ def select(self, **select_key_values: Union[ # Collect indices that match each combination of values selected_indices = [] - for value_combination in itertools.product(*select_key_values.values()): + for value_combination in itertools.product(*select_key_values.values() + ): selection = dict(zip(select_key_values.keys(), value_combination)) selected_indices.extend(self._select_indices(**selection)) @@ -956,13 +956,25 @@ def group_by(self, *line_data_keys: str) -> Self: metadata in 'line_data' not common across all 
spectra in a group will be discarded """ - get_key_items = lambda enumerated_metadata: tuple( - enumerated_metadata[1].get(item, None) for item in line_data_keys) + def get_key_items(enumerated_metadata: tuple[int, OneLineData] + ) -> tuple[str | int, ...]: + """Get sort keys from an item of enumerated input to groupby - groups = groupby(get_key_items, enumerate(self.iter_metadata())) + e.g. with line_data_keys=("a", "b") + + (0, {"a": 4, "d": 5}) --> (4, None) + """ + return tuple(enumerated_metadata[1].get(item, None) + for item in line_data_keys) - indices = lambda enumerated_values: pluck(0, enumerated_values) - sum_over_indices = lambda indices: self[list(indices)].sum() + def indices(enumerated_values: Iterable[tuple[int, Any]] + ) -> Iterator[int]: + return pluck(0, enumerated_values) + + def sum_over_indices(indices: Iterable[int]) -> Self: + return self[list(indices)].sum() + + groups = groupby(get_key_items, enumerate(self.iter_metadata())) return self.from_spectra([sum_over_indices(indices(group)) for group in groups.values()]) From 1159b5c31c504288dae9c02c90d976bfd560b680 Mon Sep 17 00:00:00 2001 From: "Adam J. 
Jackson" Date: Thu, 8 Aug 2024 11:19:42 +0100 Subject: [PATCH 08/37] Standardise import order --- euphonic/spectra.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/euphonic/spectra.py b/euphonic/spectra.py index c77b54807..fdf5210f4 100644 --- a/euphonic/spectra.py +++ b/euphonic/spectra.py @@ -7,7 +7,6 @@ import math import json from numbers import Integral, Real -from toolz.itertoolz import groupby, pluck from typing import (Any, Callable, Dict, Generator, Iterable, Iterator, List, Literal, Optional, overload, Sequence, Tuple, TypeVar, Union, Type) @@ -17,6 +16,7 @@ from pint import DimensionalityError, Quantity import numpy as np from scipy.ndimage import correlate1d, gaussian_filter +from toolz.itertoolz import groupby, pluck from euphonic import ureg, __version__ from euphonic.broadening import (ErrorFit, KernelShape, From e4b46268dbaba1d68bda60ba1342762b780537f9 Mon Sep 17 00:00:00 2001 From: "Adam J. Jackson" Date: Thu, 8 Aug 2024 11:25:47 +0100 Subject: [PATCH 09/37] More cleanup: inline one function, replace another with partial --- euphonic/spectra.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/euphonic/spectra.py b/euphonic/spectra.py index fdf5210f4..976fa3d11 100644 --- a/euphonic/spectra.py +++ b/euphonic/spectra.py @@ -3,6 +3,7 @@ from abc import ABC, abstractmethod import collections import copy +from functools import partial import itertools import math import json @@ -967,16 +968,12 @@ def get_key_items(enumerated_metadata: tuple[int, OneLineData] return tuple(enumerated_metadata[1].get(item, None) for item in line_data_keys) - def indices(enumerated_values: Iterable[tuple[int, Any]] - ) -> Iterator[int]: - return pluck(0, enumerated_values) - - def sum_over_indices(indices: Iterable[int]) -> Self: - return self[list(indices)].sum() + # First element of each tuple is the index + indices = partial(pluck, 0) groups = groupby(get_key_items, enumerate(self.iter_metadata())) - return 
self.from_spectra([sum_over_indices(indices(group)) + return self.from_spectra([self[list(indices(group))].sum() for group in groups.values()]) From 01469f8117a9b1b8503083bc1ecf6d9b7d768dc4 Mon Sep 17 00:00:00 2001 From: "Adam J. Jackson" Date: Thu, 8 Aug 2024 11:48:07 +0100 Subject: [PATCH 10/37] Remove some unused private methods/features from spectra --- euphonic/spectra.py | 35 +---------------------------------- 1 file changed, 1 insertion(+), 34 deletions(-) diff --git a/euphonic/spectra.py b/euphonic/spectra.py index 976fa3d11..536766956 100644 --- a/euphonic/spectra.py +++ b/euphonic/spectra.py @@ -892,49 +892,16 @@ def _combine_metadata(all_metadata: LineData) -> Metadata: return combined_metadata - def _tidy_metadata(self, indices: Optional[Sequence[int]] = None - ) -> Metadata: + def _tidy_metadata(self) -> Metadata: """ For a metadata dictionary, combines all common key/value pairs in 'line_data' and puts them in a top-level dictionary. - If indices is supplied, only those indices in 'line_data' are - combined. Unmatching key/value pairs are discarded """ line_data = self.metadata.get("line_data", [{}] * len(self)) - if indices is not None: - line_data = [line_data[idx] for idx in indices] combined_line_data = self._combine_metadata(line_data) combined_line_data.pop("line_data", None) return combined_line_data - def _get_line_data_vals(self, *line_data_keys: str) -> np.ndarray: - """ - Get value of the key(s) for each element in - metadata['line_data']. Returns a 1D array of tuples, where each - tuple contains the value(s) for each key in line_data_keys, for - a single element in metadata['line_data']. 
This allows easy - grouping/selecting by specific keys - - For example, if we have a Spectrum1DCollection with the following - metadata: - {'desc': 'Quartz', 'line_data': [ - {'inst': 'LET', 'sample': 0, 'index': 1}, - {'inst': 'MAPS', 'sample': 1, 'index': 2}, - {'inst': 'MARI', 'sample': 1, 'index': 1}, - ]} - Then: - _get_line_data_vals('inst', 'sample') = [('LET', 0), - ('MAPS', 1), - ('MARI', 1)] - - Raises a KeyError if 'line_data' or the key doesn't exist - """ - line_data = self.metadata['line_data'] - line_data_vals = np.empty(len(line_data), dtype=object) - for i, data in enumerate(line_data): - line_data_vals[i] = tuple([data[key] for key in line_data_keys]) - return line_data_vals - def group_by(self, *line_data_keys: str) -> Self: """ Group and sum elements of spectral data according to the values From 44ed62b107ae37790d09f25354be8f979e35eb7a Mon Sep 17 00:00:00 2001 From: "Adam J. Jackson" Date: Thu, 8 Aug 2024 11:53:16 +0100 Subject: [PATCH 11/37] Drop unused import --- euphonic/spectra.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/euphonic/spectra.py b/euphonic/spectra.py index 536766956..f20c194cb 100644 --- a/euphonic/spectra.py +++ b/euphonic/spectra.py @@ -8,7 +8,7 @@ import math import json from numbers import Integral, Real -from typing import (Any, Callable, Dict, Generator, Iterable, Iterator, List, +from typing import (Any, Callable, Dict, Generator, Iterator, List, Literal, Optional, overload, Sequence, Tuple, TypeVar, Union, Type) from typing_extensions import Self From a7ce9dbb1fe879cf0e73952aa72f460532f4d5be Mon Sep 17 00:00:00 2001 From: "Adam J. 
Jackson" Date: Thu, 8 Aug 2024 14:08:05 +0100 Subject: [PATCH 12/37] Refactor __getitem__ and move into SpectrumCollectionMixin This structure seems a bit more legible and should reduce redundancy in Spectrum2DCollection --- euphonic/spectra.py | 118 ++++++++++++++++++++++++++------------------ 1 file changed, 70 insertions(+), 48 deletions(-) diff --git a/euphonic/spectra.py b/euphonic/spectra.py index f20c194cb..6b6c27be4 100644 --- a/euphonic/spectra.py +++ b/euphonic/spectra.py @@ -8,9 +8,9 @@ import math import json from numbers import Integral, Real -from typing import (Any, Callable, Dict, Generator, Iterator, List, - Literal, Optional, overload, Sequence, Tuple, TypeVar, - Union, Type) +from operator import itemgetter +from typing import (Any, Callable, Dict, Generator, List, Literal, Optional, + overload, Sequence, Tuple, TypeVar, Union, Type) from typing_extensions import Self import warnings @@ -780,6 +780,71 @@ def from_spectra(cls, spectra: Sequence[Spectrum]) -> Self: def __len__(self): return self._get_raw_spectrum_data().shape[0] + @overload + def __getitem__(self, item: int) -> Spectrum1D: + ... + + @overload # noqa: F811 + def __getitem__(self, item: slice) -> Self: + ... + + @overload # noqa: F811 + def __getitem__(self, item: Union[Sequence[int], np.ndarray]) -> Self: + ... 
+ + def __getitem__( + self, item: Union[Integral, slice, Sequence[Integral], np.ndarray] + ): # noqa: F811 + self._validate_item(item) + init_kwargs = { + self._spectrum_data_name(): self._get_spectrum_data()[item, :], + "x_tick_labels": self.x_tick_labels, + "metadata": self._get_item_metadata(item) + } | self._get_bin_kwargs() + + if isinstance(item, Integral): + return self._item_type(**init_kwargs) + + return type(self)(**init_kwargs) + + def _validate_item(self, item: Integral | slice | Sequence[Integral] | np.ndarray + ) -> None: + """Raise Error if index has inappropriate typing/range""" + if isinstance(item, Integral): + return + if isinstance(item, slice): + if (item.stop is not None) and (item.stop >= len(self)): + raise IndexError(f'index "{item.stop}" out of range') + return + + if not all([isinstance(i, Integral) for i in item]): + raise TypeError( + f'Index "{item}" should be an integer, slice ' + f'or sequence of ints') + + @overload + def _get_item_metadata(self, item: Integral) -> OneLineData: + """Get a single metadata item with no line_data""" + + @overload + def _get_item_metadata(self, item: slice | Sequence[Integral] | np.ndarray + ) -> Metadata: # noqa: F811 + """Get a metadata collection (may include line_data)""" + + def _get_item_metadata(self, item): # noqa: F811 + """Produce appropriate metadata for __getitem__""" + metadata_lines = list(self.iter_metadata()) + + if isinstance(item, Integral): + return metadata_lines[item] + elif isinstance(item, slice): + return self._combine_metadata(metadata_lines[item]) + elif len(item) == 1: + return metadata_lines[item[0]] + else: + return self._combine_metadata( + list(itemgetter(*item)(metadata_lines))) + def copy(self) -> Self: """Get an independent copy of spectrum""" return self._item_type.copy(self) @@ -803,11 +868,11 @@ def iter_metadata(self) -> Generator[OneLineData, None, None]: """Iterate over metadata dicts of individual spectra from collection""" common_metadata = dict( (key, 
self.metadata[key]) - for key in self.metadata.keys() - set("line_data")) + for key in set(self.metadata.keys()) - {"line_data",}) line_data = self.metadata.get("line_data") if line_data is None: - line_data = itertools.repeat({}, len(self._z_data)) + line_data = itertools.repeat({}, len(self._get_raw_spectrum_data())) for one_line_data in line_data: yield common_metadata | one_line_data @@ -1048,49 +1113,6 @@ def _split_by_indices(self, metadata=self.metadata) for x0, x1 in ranges] - @overload - def __getitem__(self, item: int) -> Spectrum1D: - ... - - @overload # noqa: F811 - def __getitem__(self, item: slice) -> Self: - ... - - @overload # noqa: F811 - def __getitem__(self, item: Union[Sequence[int], np.ndarray]) -> Self: - ... - - def __getitem__(self, item: Union[int, slice, Sequence[int], np.ndarray] - ): # noqa: F811 - new_metadata = copy.deepcopy(self.metadata) - line_metadata = new_metadata.pop('line_data', - [{} for _ in self._y_data]) - if isinstance(item, Integral): - new_metadata.update(line_metadata[item]) - return Spectrum1D(self.x_data, - self.y_data[item, :], - x_tick_labels=self.x_tick_labels, - metadata=new_metadata) - - if isinstance(item, slice): - if (item.stop is not None) and (item.stop >= len(self)): - raise IndexError(f'index "{item.stop}" out of range') - new_metadata.update(self._combine_metadata(line_metadata[item])) - else: - try: - item = list(item) - if not all([isinstance(i, Integral) for i in item]): - raise TypeError - except TypeError: - raise TypeError(f'Index "{item}" should be an integer, slice ' - f'or sequence of ints') - new_metadata.update(self._combine_metadata( - [line_metadata[i] for i in item])) - return type(self)(self.x_data, - self.y_data[item, :], - x_tick_labels=self.x_tick_labels, - metadata=new_metadata) - @classmethod def from_spectra(cls: Type[T], spectra: Sequence[Spectrum1D]) -> T: if len(spectra) < 1: From 0e05e0e2083ce28f1d8a119beab5cadfa123f8dd Mon Sep 17 00:00:00 2001 From: "Adam J. 
Jackson" Date: Thu, 8 Aug 2024 14:19:16 +0100 Subject: [PATCH 13/37] Tidying: indentation, more pedantic typing --- euphonic/spectra.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/euphonic/spectra.py b/euphonic/spectra.py index 6b6c27be4..26536ce10 100644 --- a/euphonic/spectra.py +++ b/euphonic/spectra.py @@ -781,7 +781,7 @@ def __len__(self): return self._get_raw_spectrum_data().shape[0] @overload - def __getitem__(self, item: int) -> Spectrum1D: + def __getitem__(self, item: int) -> Spectrum: ... @overload # noqa: F811 @@ -836,7 +836,7 @@ def _get_item_metadata(self, item): # noqa: F811 metadata_lines = list(self.iter_metadata()) if isinstance(item, Integral): - return metadata_lines[item] + return metadata_lines[item] elif isinstance(item, slice): return self._combine_metadata(metadata_lines[item]) elif len(item) == 1: From 79ff12772613098da7668149fc99d3cffb0ad2f1 Mon Sep 17 00:00:00 2001 From: "Adam J. Jackson" Date: Thu, 8 Aug 2024 14:20:50 +0100 Subject: [PATCH 14/37] More linting --- euphonic/spectra.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/euphonic/spectra.py b/euphonic/spectra.py index 26536ce10..6d6481f8b 100644 --- a/euphonic/spectra.py +++ b/euphonic/spectra.py @@ -816,7 +816,7 @@ def _validate_item(self, item: Integral | slice | Sequence[Integral] | np.ndarra if (item.stop is not None) and (item.stop >= len(self)): raise IndexError(f'index "{item.stop}" out of range') return - + if not all([isinstance(i, Integral) for i in item]): raise TypeError( f'Index "{item}" should be an integer, slice ' @@ -837,13 +837,12 @@ def _get_item_metadata(self, item): # noqa: F811 if isinstance(item, Integral): return metadata_lines[item] - elif isinstance(item, slice): + if isinstance(item, slice): return self._combine_metadata(metadata_lines[item]) - elif len(item) == 1: + if len(item) == 1: return metadata_lines[item[0]] - else: - return self._combine_metadata( - 
list(itemgetter(*item)(metadata_lines))) + return self._combine_metadata( + list(itemgetter(*item)(metadata_lines))) def copy(self) -> Self: """Get an independent copy of spectrum""" From a041d72a67faccd302c8dcb2401862a7e7d426b1 Mon Sep 17 00:00:00 2001 From: "Adam J. Jackson" Date: Thu, 8 Aug 2024 14:41:47 +0100 Subject: [PATCH 15/37] Generalise more Spectrum1DCollection methods to mixin class --- euphonic/spectra.py | 104 +++++++++++++++++++++++++------------------- 1 file changed, 59 insertions(+), 45 deletions(-) diff --git a/euphonic/spectra.py b/euphonic/spectra.py index 6d6481f8b..66fcb0efb 100644 --- a/euphonic/spectra.py +++ b/euphonic/spectra.py @@ -710,11 +710,13 @@ class SpectrumCollectionMixin(ABC): _item_type = Spectrum1D # Define some private methods which wrap this information into useful forms - def _spectrum_data_name(self) -> str: - return f"{self._spectrum_axis}_data" + @classmethod + def _spectrum_data_name(cls) -> str: + return f"{cls._spectrum_axis}_data" - def _spectrum_raw_data_name(self) -> str: - return f"_{self._spectrum_axis}_data" + @classmethod + def _spectrum_raw_data_name(cls) -> str: + return f"_{cls._spectrum_axis}_data" def _get_spectrum_data(self) -> Quantity: return getattr(self, self._spectrum_data_name()) @@ -825,7 +827,7 @@ def _validate_item(self, item: Integral | slice | Sequence[Integral] | np.ndarra @overload def _get_item_metadata(self, item: Integral) -> OneLineData: """Get a single metadata item with no line_data""" - + @overload def _get_item_metadata(self, item: slice | Sequence[Integral] | np.ndarray ) -> Metadata: # noqa: F811 @@ -1008,6 +1010,58 @@ def get_key_items(enumerated_metadata: tuple[int, OneLineData] for group in groups.values()]) + def to_dict(self) -> Dict[str, Any]: + """ + Convert to a dictionary consistent with from_dict() + + Returns + ------- + dict + """ + attrs = [*self._get_bin_kwargs().keys(), + self._spectrum_data_name(), + 'x_tick_labels', + 'metadata'] + + return _obj_to_dict(self, attrs) 
+ + @classmethod + def from_dict(cls: Self, d: dict) -> Self: + """ + Convert a dictionary to a Spectrum Collection object + + Parameters + ---------- + d : dict + A dictionary with the following keys/values: + + - 'x_data': (n_x_data,) or (n_x_data + 1,) float ndarray + - 'x_data_unit': str + - 'y_data': (n_x_data,) float ndarray + - 'y_data_unit': str + + There are also the following optional keys: + + - 'x_tick_labels': list of (int, string) tuples + - 'metadata': dict + + Returns + ------- + spectrum_collection + """ + data_keys = list(f"{dim}_data" for dim in cls._bin_axes) + data_keys.append(cls._spectrum_data_name()) + + d = _process_dict(d, + quantities=data_keys, + optional=['x_tick_labels', 'metadata']) + + data_args = [d[key] for key in data_keys] + return cls(*data_args, + x_tick_labels=d['x_tick_labels'], + metadata=d['metadata']) + + class Spectrum1DCollection(SpectrumCollectionMixin, Spectrum, collections.abc.Sequence): @@ -1143,17 +1197,6 @@ def _type_check(spectrum): return cls(x_data, y_data, x_tick_labels=x_tick_labels, metadata=metadata) - def to_dict(self) -> Dict[str, Any]: - """ - Convert to a dictionary consistent with from_dict() - - Returns - ------- - dict - """ - return _obj_to_dict(self, ['x_data', 'y_data', 'x_tick_labels', - 'metadata']) - def to_text_file(self, filename: str, fmt: Optional[Union[str, Sequence[str]]] = None) -> None: """ @@ -1187,35 +1230,6 @@ def to_text_file(self, filename: str, kwargs['fmt'] = fmt np.savetxt(filename, out_data, **kwargs) - @classmethod - def from_dict(cls: Type[T], d) -> T: - """ - Convert a dictionary to a Spectrum1DCollection object - - Parameters - ---------- - d : dict - A dictionary with the following keys/values: - - - 'x_data': (n_x_data,) or (n_x_data + 1,) float ndarray - - 'x_data_unit': str - - 'y_data': (n_x_data,) float ndarray - - 'y_data_unit': str - - There are also the following optional keys: - - - 'x_tick_labels': list of (int, string) tuples - - 'metadata': dict - - Returns - 
------- - spectrum_collection - """ - d = _process_dict(d, quantities=['x_data', 'y_data'], - optional=['x_tick_labels', 'metadata']) - return cls(d['x_data'], d['y_data'], x_tick_labels=d['x_tick_labels'], - metadata=d['metadata']) - @classmethod def from_castep_phonon_dos(cls: Type[T], filename: str) -> T: """ From 35544a180bfe50a7152adaab29e95ce0d8e16b33 Mon Sep 17 00:00:00 2001 From: "Adam J. Jackson" Date: Thu, 8 Aug 2024 14:46:33 +0100 Subject: [PATCH 16/37] Move detailed from_dict docstring to child class --- euphonic/spectra.py | 49 ++++++++++++++++++++++++--------------------- 1 file changed, 26 insertions(+), 23 deletions(-) diff --git a/euphonic/spectra.py b/euphonic/spectra.py index 66fcb0efb..db375009f 100644 --- a/euphonic/spectra.py +++ b/euphonic/spectra.py @@ -1009,7 +1009,6 @@ def get_key_items(enumerated_metadata: tuple[int, OneLineData] return self.from_spectra([self[list(indices(group))].sum() for group in groups.values()]) - def to_dict(self) -> Dict[str, Any]: """ Convert to a dictionary consistent with from_dict() @@ -1027,28 +1026,7 @@ def to_dict(self) -> Dict[str, Any]: @classmethod def from_dict(cls: Self, d: dict) -> Self: - """ - Convert a dictionary to a Spectrum Collection object - - Parameters - ---------- - d : dict - A dictionary with the following keys/values: - - - 'x_data': (n_x_data,) or (n_x_data + 1,) float ndarray - - 'x_data_unit': str - - 'y_data': (n_x_data,) float ndarray - - 'y_data_unit': str - - There are also the following optional keys: - - - 'x_tick_labels': list of (int, string) tuples - - 'metadata': dict - - Returns - ------- - spectrum_collection - """ + """Initialise a Spectrum Collection object from dict""" data_keys = list(f"{dim}_data" for dim in cls._bin_axes) data_keys.append(cls._spectrum_data_name()) @@ -1357,6 +1335,31 @@ def broaden(self: T, else: raise TypeError("x_width must be a Quantity or Callable") + @classmethod + def from_dict(cls: Self, d: dict) -> Self: + """ + Convert a dictionary to a 
Spectrum Collection object + + Parameters + ---------- + d : dict + A dictionary with the following keys/values: + + - 'x_data': (n_x_data,) or (n_x_data + 1,) float ndarray + - 'x_data_unit': str + - 'y_data': (n_x_data,) float ndarray + - 'y_data_unit': str + + There are also the following optional keys: + + - 'x_tick_labels': list of (int, string) tuples + - 'metadata': dict + + Returns + ------- + spectrum_collection + """ + return super().from_dict(d) class Spectrum2D(Spectrum): """ From c445fb90d92405ac4b0f5cacde674a068d2a66e6 Mon Sep 17 00:00:00 2001 From: "Adam J. Jackson" Date: Fri, 9 Aug 2024 10:13:18 +0100 Subject: [PATCH 17/37] Begin implementing Spectrum2DCollection --- euphonic/spectra.py | 91 ++++++++++++++++++- .../test_spectrum2dcollection.py | 24 +++++ 2 files changed, 114 insertions(+), 1 deletion(-) create mode 100644 tests_and_analysis/test/euphonic_test/test_spectrum2dcollection.py diff --git a/euphonic/spectra.py b/euphonic/spectra.py index db375009f..f9ffcb2db 100644 --- a/euphonic/spectra.py +++ b/euphonic/spectra.py @@ -1391,7 +1391,7 @@ class Spectrum2D(Spectrum): def __init__(self, x_data: Quantity, y_data: Quantity, z_data: Quantity, x_tick_labels: Optional[Sequence[Tuple[int, str]]] = None, - metadata: Optional[Dict[str, Union[int, str]]] = None + metadata: Optional[Metadata] = None ) -> None: """ Parameters @@ -1768,6 +1768,95 @@ def from_dict(cls: Type[T], d: Dict[str, Any]) -> T: metadata=d['metadata']) +class Spectrum2DCollection(SpectrumCollectionMixin, + Spectrum, + collections.abc.Sequence): + """A collection of Spectrum2D with common x_data, y_data and x_tick_labels + + Intended for convenient storage of contributions to spectral maps such as + S(Q,w). This object can be indexed or iterated to obtain individual + Spectrum2D. + + Attributes + ---------- + x_data + Shape (n_x_data,) or (n_x_data + 1,) float Quantity. 
The x_data + points (if size == (n_x_data,)) or x_data bin edges (if size + == (n_x_data + 1,)) + y_data + Shape (n_y_data,) or (n_y_data + 1,) float Quantity. The y_data + points (if size == (n_y_data,)) or y_data bin edges (if size + == (n_y_data + 1,)) + z_data + Shape (n_entries, n_x_data, n_y_data) float Quantity. The spectral data + in x and y, indexed over components + x_tick_labels + Sequence[Tuple[int, str]] or None. Special tick labels e.g. for + high-symmetry points. The int refers to the index in x_data the + label should be applied to + metadata + Dict[str, Union[int, str, LineData]] or None. Contains metadata + about the spectra. Keys should be strings and values should be + strings or integers. + There are some functional keys: + + - 'line_data' : LineData + This is a Sequence[Dict[str, Union[int, str]], + it contains metadata for each spectrum in + the collection, and must be of length + n_entries + """ + + # Private attributes used by SpectrumCollectionMixin + _spectrum_axis = "z" + _item_type = Spectrum2D + + def __init__( + self, x_data: Quantity, y_data: Quantity, z_data: Quantity, + x_tick_labels: Optional[Sequence[Tuple[int, str]]] = None, + metadata: Optional[Metadata] = None + ) -> None: + _check_constructor_inputs( + [z_data, x_tick_labels, metadata], + [Quantity, [list, type(None)], [dict, type(None)]], + [(-1, -1, -1), (), ()], + ['z_data', 'x_tick_labels', 'metadata']) + nx = z_data.shape[1] + ny = z_data.shape[2] + _check_constructor_inputs( + [x_data, y_data], + [Quantity, Quantity], + [[(nx,), (nx + 1,)], [(ny,), (ny + 1,)]], + ['x_data', 'y_data']) + + self._set_data(x_data, 'x') + self._set_data(y_data, 'y') + self.x_tick_labels = x_tick_labels + self._set_data(z_data, 'z') + if metadata and 'line_data' in metadata.keys(): + if len(metadata['line_data']) != len(z_data): + raise ValueError( + f'z_data contains {len(z_data)} spectra, but ' + f'metadata["line_data"] contains ' + f'{len(metadata["line_data"])} entries') + self.metadata 
= {} if metadata is None else metadata + + def _split_by_indices(self, indices: Sequence[int] | np.ndarray + ) -> List[Self]: + """Split data along x axis at given indices""" + ranges = self._ranges_from_indices(indices) + return [type(self)(self.x_data[x0:x1], + self.y_data, + self.z_data[:, x0:x1, :], + x_tick_labels=self._cut_x_ticks( + self.x_tick_labels, x0, x1), + metadata=self.metadata) + for x0, x1 in ranges] + + @classmethod + def from_spectra(cls, spectra: Sequence[Spectrum2D]) -> Self: + raise NotImplementedError() + def apply_kinematic_constraints(spectrum: Spectrum2D, e_i: Quantity = None, e_f: Quantity = None, diff --git a/tests_and_analysis/test/euphonic_test/test_spectrum2dcollection.py b/tests_and_analysis/test/euphonic_test/test_spectrum2dcollection.py new file mode 100644 index 000000000..c395675a1 --- /dev/null +++ b/tests_and_analysis/test/euphonic_test/test_spectrum2dcollection.py @@ -0,0 +1,24 @@ +import numpy as np + +from euphonic import ureg +from euphonic.spectra import Spectrum2DCollection + +class TestSpectrum2DCollectionCreation: + def test_init_from_numbers(self): + N_X = 10 + N_Y = 20 + N_Z = 5 + + x_data = ureg.Quantity(np.linspace(0, 100, N_X), "1 / angstrom") + y_data = ureg.Quantity(np.linspace(0, 2000, N_Y), "meV") + z_data = ureg.Quantity(np.random.random((N_Z, N_X, N_Y)), "1 / meV") + + metadata = {"flavour": "chocolate", + "line_data": [{"index": i} for i in range(N_Z)]} + + x_tick_labels = [(0, "Start"), (N_X - 1, "END")] + + spectrum = Spectrum2DCollection( + x_data, y_data, z_data, + x_tick_labels=x_tick_labels, metadata=metadata) + From e9bff3d65e1d08042e89073c6780fc4818d8cf08 Mon Sep 17 00:00:00 2001 From: "Adam J. 
Jackson" Date: Wed, 4 Sep 2024 09:36:24 +0100 Subject: [PATCH 18/37] Spectrum2DCollection; initial implementation and rough tests --- euphonic/spectra.py | 44 ++++++++++++++++- .../test_spectrum2dcollection.py | 49 ++++++++++++++++++- 2 files changed, 90 insertions(+), 3 deletions(-) diff --git a/euphonic/spectra.py b/euphonic/spectra.py index f9ffcb2db..ae3351600 100644 --- a/euphonic/spectra.py +++ b/euphonic/spectra.py @@ -1808,6 +1808,7 @@ class Spectrum2DCollection(SpectrumCollectionMixin, """ # Private attributes used by SpectrumCollectionMixin + _bin_axes = ("x", "y") _spectrum_axis = "z" _item_type = Spectrum2D @@ -1855,7 +1856,48 @@ def _split_by_indices(self, indices: Sequence[int] | np.ndarray @classmethod def from_spectra(cls, spectra: Sequence[Spectrum2D]) -> Self: - raise NotImplementedError() + if len(spectra) < 1: + raise IndexError("At least one spectrum is needed for collection") + + def _type_check(spectrum): + if not isinstance(spectrum, Spectrum2D): + raise TypeError( + "from_spectra() requires a sequence of Spectrum2D") + + _type_check(spectra[0]) + bins_data = { + f"{ax}_data": getattr(spectra[0], f"{ax}_data") + for ax in cls._bin_axes + } + x_tick_labels = spectra[0].x_tick_labels + + spectrum_0_data = getattr(spectra[0], f"{cls._spectrum_axis}_data") + spectrum_data_shape = spectrum_0_data.shape + spectrum_data_magnitude = np.empty((len(spectra), *spectrum_data_shape)) + spectrum_data_magnitude[0, :, :] = spectrum_0_data.magnitude + spectrum_data_units = spectrum_0_data.units + + for i, spectrum in enumerate(spectra[1:]): + _type_check(spectrum) + spectrum_i_data = getattr(spectrum, f"_{cls._spectrum_axis}_data") + spectrum_i_data_units = getattr(spectrum, f"{cls._spectrum_axis}_data_unit") + assert (spectrum_i_data_units == spectrum_data_units) + + for key, ref_bins in bins_data.items(): + item_bins = getattr(spectrum, key) + assert np.allclose(item_bins.magnitude, ref_bins.magnitude) + assert item_bins.units == ref_bins.units + + assert 
spectrum.x_tick_labels == x_tick_labels + spectrum_data_magnitude[i + 1, :, :] = spectrum_i_data + + metadata = cls._combine_metadata([spec.metadata for spec in spectra]) + spectrum_data = Quantity(spectrum_data_magnitude, spectrum_data_units) + return cls(**bins_data, + **{f"{cls._spectrum_axis}_data": spectrum_data}, + x_tick_labels=x_tick_labels, + metadata=metadata) + def apply_kinematic_constraints(spectrum: Spectrum2D, e_i: Quantity = None, diff --git a/tests_and_analysis/test/euphonic_test/test_spectrum2dcollection.py b/tests_and_analysis/test/euphonic_test/test_spectrum2dcollection.py index c395675a1..bfc46d37e 100644 --- a/tests_and_analysis/test/euphonic_test/test_spectrum2dcollection.py +++ b/tests_and_analysis/test/euphonic_test/test_spectrum2dcollection.py @@ -1,7 +1,41 @@ +from typing import Optional + import numpy as np +import pytest + +from euphonic import Quantity, ureg +from euphonic.spectra import OneLineData, Spectrum2D, Spectrum2DCollection + +from .test_spectrum2d import check_spectrum2d + +# def check_spectrum2d(actual_spectrum2d, expected_spectrum2d, equal_nan=False, +# z_atol=np.finfo(np.float64).eps): + + +def rand_spectrum2d(seed: int = 1, + x_bins: Optional[Quantity] = None, + y_bins: Optional[Quantity] = None, + metadata: Optional[OneLineData] = None) -> Spectrum2D: + rng = np.random.default_rng(seed=seed) + + if x_bins is None: + x_bins = np.linspace(*sorted([rng.random(), rng.random()]), + rng.integers(3, 10), + ) * ureg("1 / angstrom") + if y_bins is None: + y_bins = np.linspace(*sorted([rng.random(), rng.random()]), + rng.integers(3, 10)) * ureg("meV") + if metadata is None: + metadata = {"index": rng.integers(10), + "value": rng.random(), + "tag": "common"} + + spectrum = Spectrum2D(x_data=x_bins, + y_data=y_bins, + z_data=rng.random([len(x_bins) - 1, len(y_bins) - 1]) * ureg("millibarn / meV"), + metadata=metadata) + return spectrum -from euphonic import ureg -from euphonic.spectra import Spectrum2DCollection class 
TestSpectrum2DCollectionCreation: def test_init_from_numbers(self): @@ -22,3 +56,14 @@ def test_init_from_numbers(self): x_data, y_data, z_data, x_tick_labels=x_tick_labels, metadata=metadata) + def test_init_from_spectra(self): + spec_2d = rand_spectrum2d(seed=1) + spec_2d_consistent = rand_spectrum2d().copy() + spec_2d_consistent._z_data *= 2 + spec_2d.metadata["index"] = 2 + + spectrum = Spectrum2DCollection.from_spectra([spec_2d, spec_2d_consistent]) + + spec_2d_inconsistent = rand_spectrum2d(seed=2) + with pytest.raises(ValueError): + spectrum = Spectrum2DCollection.from_spectra([spec_2d, spec_2d_inconsistent]) From ae3648b3c076952ed3902aa9b2b06daafd2f545c Mon Sep 17 00:00:00 2001 From: "Adam J. Jackson" Date: Mon, 16 Sep 2024 16:50:02 +0100 Subject: [PATCH 19/37] Linting --- .../euphonic_test/test_spectrum2dcollection.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/tests_and_analysis/test/euphonic_test/test_spectrum2dcollection.py b/tests_and_analysis/test/euphonic_test/test_spectrum2dcollection.py index bfc46d37e..317ddd8b5 100644 --- a/tests_and_analysis/test/euphonic_test/test_spectrum2dcollection.py +++ b/tests_and_analysis/test/euphonic_test/test_spectrum2dcollection.py @@ -6,8 +6,6 @@ from euphonic import Quantity, ureg from euphonic.spectra import OneLineData, Spectrum2D, Spectrum2DCollection -from .test_spectrum2d import check_spectrum2d - # def check_spectrum2d(actual_spectrum2d, expected_spectrum2d, equal_nan=False, # z_atol=np.finfo(np.float64).eps): @@ -32,7 +30,8 @@ def rand_spectrum2d(seed: int = 1, spectrum = Spectrum2D(x_data=x_bins, y_data=y_bins, - z_data=rng.random([len(x_bins) - 1, len(y_bins) - 1]) * ureg("millibarn / meV"), + z_data=rng.random([len(x_bins) - 1, len(y_bins) - 1] + ) * ureg("millibarn / meV"), metadata=metadata) return spectrum @@ -56,14 +55,19 @@ def test_init_from_numbers(self): x_data, y_data, z_data, x_tick_labels=x_tick_labels, metadata=metadata) + assert spectrum + def 
test_init_from_spectra(self): spec_2d = rand_spectrum2d(seed=1) spec_2d_consistent = rand_spectrum2d().copy() spec_2d_consistent._z_data *= 2 spec_2d.metadata["index"] = 2 - spectrum = Spectrum2DCollection.from_spectra([spec_2d, spec_2d_consistent]) + spectrum = Spectrum2DCollection.from_spectra( + [spec_2d, spec_2d_consistent]) spec_2d_inconsistent = rand_spectrum2d(seed=2) with pytest.raises(ValueError): - spectrum = Spectrum2DCollection.from_spectra([spec_2d, spec_2d_inconsistent]) + spectrum = Spectrum2DCollection.from_spectra( + [spec_2d, spec_2d_inconsistent]) + assert spectrum From 4659cad5bfa26a9b3fd22ecf9d7d0c70017a0a06 Mon Sep 17 00:00:00 2001 From: "Adam J. Jackson" Date: Mon, 16 Sep 2024 16:52:32 +0100 Subject: [PATCH 20/37] More linting --- .../test/euphonic_test/test_spectrum2dcollection.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests_and_analysis/test/euphonic_test/test_spectrum2dcollection.py b/tests_and_analysis/test/euphonic_test/test_spectrum2dcollection.py index 317ddd8b5..f5877f8a5 100644 --- a/tests_and_analysis/test/euphonic_test/test_spectrum2dcollection.py +++ b/tests_and_analysis/test/euphonic_test/test_spectrum2dcollection.py @@ -6,14 +6,12 @@ from euphonic import Quantity, ureg from euphonic.spectra import OneLineData, Spectrum2D, Spectrum2DCollection -# def check_spectrum2d(actual_spectrum2d, expected_spectrum2d, equal_nan=False, -# z_atol=np.finfo(np.float64).eps): - def rand_spectrum2d(seed: int = 1, x_bins: Optional[Quantity] = None, y_bins: Optional[Quantity] = None, metadata: Optional[OneLineData] = None) -> Spectrum2D: + """Generate a Spectrum2D with random axis lengths, ranges, and metadata""" rng = np.random.default_rng(seed=seed) if x_bins is None: @@ -38,6 +36,7 @@ def rand_spectrum2d(seed: int = 1, class TestSpectrum2DCollectionCreation: def test_init_from_numbers(self): + """Construct Spectrum2DCollection with __init__()""" N_X = 10 N_Y = 20 N_Z = 5 @@ -58,6 +57,7 @@ def 
test_init_from_numbers(self): assert spectrum def test_init_from_spectra(self): + """Construct collection from a series of Spectrum2D""" spec_2d = rand_spectrum2d(seed=1) spec_2d_consistent = rand_spectrum2d().copy() spec_2d_consistent._z_data *= 2 From e32aa456eec04da7282a9e93a0e6a6f9d6a9edf0 Mon Sep 17 00:00:00 2001 From: "Adam J. Jackson" Date: Mon, 16 Sep 2024 17:03:26 +0100 Subject: [PATCH 21/37] Refactor spectrum item data access to methods Collect the axis-twiddling code in one place to improve readability. --- euphonic/spectra.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/euphonic/spectra.py b/euphonic/spectra.py index ae3351600..0913ad0bf 100644 --- a/euphonic/spectra.py +++ b/euphonic/spectra.py @@ -746,6 +746,18 @@ def _get_bin_kwargs(self) -> Dict[str, Quantity]: return {f"{axis}_data": getattr(self, f"{axis}_data") for axis in self._bin_axes} + @classmethod + def _get_item_data(cls, item: Spectrum) -> Quantity: + return getattr(item, f"{cls._spectrum_axis}_data") + + @classmethod + def _get_item_raw_data(cls, item: Spectrum) -> np.ndarray: + return getattr(item, f"_{cls._spectrum_axis}_data") + + @classmethod + def _get_item_data_unit(cls, item: Spectrum) -> str: + return getattr(item, f"{cls._spectrum_axis}_data_unit") + def sum(self) -> Spectrum: """ Sum collection to a single spectrum @@ -1871,16 +1883,17 @@ def _type_check(spectrum): } x_tick_labels = spectra[0].x_tick_labels - spectrum_0_data = getattr(spectra[0], f"{cls._spectrum_axis}_data") + spectrum_0_data = cls._get_item_data(spectra[0]) spectrum_data_shape = spectrum_0_data.shape - spectrum_data_magnitude = np.empty((len(spectra), *spectrum_data_shape)) + spectrum_data_magnitude = np.empty( + (len(spectra), *spectrum_data_shape)) spectrum_data_magnitude[0, :, :] = spectrum_0_data.magnitude spectrum_data_units = spectrum_0_data.units for i, spectrum in enumerate(spectra[1:]): _type_check(spectrum) - spectrum_i_data = getattr(spectrum, 
f"_{cls._spectrum_axis}_data") - spectrum_i_data_units = getattr(spectrum, f"{cls._spectrum_axis}_data_unit") + spectrum_i_raw_data = cls._get_item_raw_data(spectrum) + spectrum_i_data_units = cls._get_item_data_unit(spectrum) assert (spectrum_i_data_units == spectrum_data_units) for key, ref_bins in bins_data.items(): @@ -1889,7 +1902,7 @@ def _type_check(spectrum): assert item_bins.units == ref_bins.units assert spectrum.x_tick_labels == x_tick_labels - spectrum_data_magnitude[i + 1, :, :] = spectrum_i_data + spectrum_data_magnitude[i + 1, :, :] = spectrum_i_raw_data metadata = cls._combine_metadata([spec.metadata for spec in spectra]) spectrum_data = Quantity(spectrum_data_magnitude, spectrum_data_units) From fb87b7dd144dff6e3962375cc6577cc0065f8790 Mon Sep 17 00:00:00 2001 From: "Adam J. Jackson" Date: Tue, 17 Sep 2024 11:01:55 +0100 Subject: [PATCH 22/37] Implement Spectrum2DCollection z_data; test slicing with ref data --- euphonic/spectra.py | 11 + .../data/spectrum2d/quartz_fuzzy_map_0.json | 267 +++++++ .../data/spectrum2d/quartz_fuzzy_map_1.json | 267 +++++++ .../data/spectrum2d/quartz_fuzzy_map_2.json | 267 +++++++ .../quartz_fuzzy_map.json | 703 ++++++++++++++++++ .../test_spectrum2dcollection.py | 55 ++ 6 files changed, 1570 insertions(+) create mode 100644 tests_and_analysis/test/data/spectrum2d/quartz_fuzzy_map_0.json create mode 100644 tests_and_analysis/test/data/spectrum2d/quartz_fuzzy_map_1.json create mode 100644 tests_and_analysis/test/data/spectrum2d/quartz_fuzzy_map_2.json create mode 100644 tests_and_analysis/test/data/spectrum2dcollection/quartz_fuzzy_map.json diff --git a/euphonic/spectra.py b/euphonic/spectra.py index 0913ad0bf..66219fd0b 100644 --- a/euphonic/spectra.py +++ b/euphonic/spectra.py @@ -1866,6 +1866,17 @@ def _split_by_indices(self, indices: Sequence[int] | np.ndarray metadata=self.metadata) for x0, x1 in ranges] + @property + def z_data(self) -> Quantity: + return ureg.Quantity( + self._z_data, self._internal_z_data_unit + 
).to(self.z_data_unit, "reciprocal_spectroscopy") + + @z_data.setter + def z_data(self, value: Quantity) -> None: + self.z_data_unit = str(value.units) + self._z_data = value.to(self._internal_z_data_unit).magnitude + @classmethod def from_spectra(cls, spectra: Sequence[Spectrum2D]) -> Self: if len(spectra) < 1: diff --git a/tests_and_analysis/test/data/spectrum2d/quartz_fuzzy_map_0.json b/tests_and_analysis/test/data/spectrum2d/quartz_fuzzy_map_0.json new file mode 100644 index 000000000..015da5952 --- /dev/null +++ b/tests_and_analysis/test/data/spectrum2d/quartz_fuzzy_map_0.json @@ -0,0 +1,267 @@ +{ + "__euphonic_class__": "Spectrum2D", + "__euphonic_version__": "1.3.2+33.gd8680c2.dirty", + "metadata": { + "common": "yes", + "direction": 0 + }, + "x_data": [ + 0.0, + 0.1111111111111111, + 0.2222222222222222, + 0.3333333333333333, + 0.4444444444444444, + 0.5555555555555556, + 0.6666666666666666, + 0.7777777777777777, + 0.8888888888888888, + 1.0 + ], + "x_data_unit": "1 / angstrom", + "x_tick_labels": [ + [ + 0, + "$\\Gamma$" + ], + [ + 9, + "" + ] + ], + "y_data": [ + 0.0, + 5.2631578947368425, + 10.526315789473685, + 15.789473684210527, + 21.05263157894737, + 26.315789473684212, + 31.578947368421055, + 36.8421052631579, + 42.10526315789474, + 47.36842105263158, + 52.631578947368425, + 57.89473684210527, + 63.15789473684211, + 68.42105263157896, + 73.6842105263158, + 78.94736842105263, + 84.21052631578948, + 89.47368421052633, + 94.73684210526316, + 100.0 + ], + "y_data_unit": "millielectron_volt", + "z_data": [ + [ + 0.0, + 0.0, + 0.0, + 0.042222222222222223, + 0.0, + 0.021111111111111094, + 0.04222222222222224, + 0.0, + 0.06333333333333335, + 0.021111111111111094, + 0.04222222222222224, + 0.04222222222222215, + 0.0, + 0.0, + 0.0, + 0.0, + 0.04222222222222215, + 0.0, + 0.04222222222222224 + ], + [ + 0.06333333333333331, + 0.0, + 0.0, + 0.042222222222222223, + 0.0, + 0.021111111111111094, + 0.04222222222222224, + 0.021111111111111094, + 0.04222222222222224, + 
0.021111111111111094, + 0.04222222222222224, + 0.04222222222222215, + 0.0, + 0.0, + 0.0, + 0.0, + 0.04222222222222215, + 0.0, + 0.04222222222222224 + ], + [ + 0.0211111111111111, + 0.0422222222222222, + 0.0, + 0.042222222222222223, + 0.0, + 0.021111111111111094, + 0.04222222222222224, + 0.04222222222222219, + 0.0, + 0.04222222222222219, + 0.04222222222222224, + 0.04222222222222215, + 0.0, + 0.0, + 0.0, + 0.021111111111111074, + 0.021111111111111074, + 0.0, + 0.04222222222222224 + ], + [ + 0.0, + 0.06333333333333331, + 0.0, + 0.042222222222222223, + 0.0, + 0.04222222222222219, + 0.02111111111111112, + 0.04222222222222219, + 0.0, + 0.04222222222222219, + 0.04222222222222224, + 0.04222222222222215, + 0.0, + 0.0, + 0.0, + 0.021111111111111074, + 0.021111111111111074, + 0.0, + 0.02111111111111112 + ], + [ + 0.0, + 0.0211111111111111, + 0.0422222222222222, + 0.021111111111111112, + 0.021111111111111112, + 0.04222222222222219, + 0.0, + 0.06333333333333328, + 0.0, + 0.06333333333333328, + 0.02111111111111112, + 0.021111111111111074, + 0.02111111111111112, + 0.0, + 0.0, + 0.021111111111111074, + 0.0, + 0.021111111111111164, + 0.02111111111111112 + ], + [ + 0.0, + 0.0211111111111111, + 0.0211111111111111, + 0.042222222222222223, + 0.021111111111111112, + 0.04222222222222219, + 0.02111111111111112, + 0.04222222222222219, + 0.0, + 0.06333333333333328, + 0.02111111111111112, + 0.021111111111111074, + 0.02111111111111112, + 0.0, + 0.02111111111111112, + 0.0, + 0.0, + 0.021111111111111164, + 0.02111111111111112 + ], + [ + 0.0, + 0.0211111111111111, + 0.0211111111111111, + 0.042222222222222223, + 0.021111111111111112, + 0.04222222222222219, + 0.02111111111111112, + 0.021111111111111094, + 0.02111111111111112, + 0.06333333333333328, + 0.04222222222222224, + 0.0, + 0.02111111111111112, + 0.0, + 0.02111111111111112, + 0.0, + 0.0, + 0.021111111111111164, + 0.02111111111111112 + ], + [ + 0.0, + 0.0, + 0.0422222222222222, + 0.021111111111111112, + 0.042222222222222223, + 
0.04222222222222219, + 0.04222222222222224, + 0.0, + 0.04222222222222224, + 0.04222222222222219, + 0.04222222222222224, + 0.0, + 0.0, + 0.04222222222222224, + 0.0, + 0.0, + 0.0, + 0.021111111111111164, + 0.04222222222222224 + ], + [ + 0.0, + 0.0, + 0.0422222222222222, + 0.021111111111111112, + 0.042222222222222223, + 0.04222222222222219, + 0.04222222222222224, + 0.0, + 0.08444444444444447, + 0.0, + 0.04222222222222224, + 0.0, + 0.0, + 0.04222222222222224, + 0.0, + 0.0, + 0.0, + 0.021111111111111164, + 0.06333333333333335 + ], + [ + 0.0, + 0.0211111111111111, + 0.0211111111111111, + 0.021111111111111112, + 0.042222222222222223, + 0.04222222222222219, + 0.04222222222222224, + 0.0, + 0.04222222222222224, + 0.04222222222222219, + 0.04222222222222224, + 0.0, + 0.02111111111111112, + 0.0, + 0.02111111111111112, + 0.0, + 0.0, + 0.021111111111111164, + 0.04222222222222224 + ] + ], + "z_data_unit": "1 / millielectron_volt" +} \ No newline at end of file diff --git a/tests_and_analysis/test/data/spectrum2d/quartz_fuzzy_map_1.json b/tests_and_analysis/test/data/spectrum2d/quartz_fuzzy_map_1.json new file mode 100644 index 000000000..d2c5fd013 --- /dev/null +++ b/tests_and_analysis/test/data/spectrum2d/quartz_fuzzy_map_1.json @@ -0,0 +1,267 @@ +{ + "__euphonic_class__": "Spectrum2D", + "__euphonic_version__": "1.3.2+33.gd8680c2.dirty", + "metadata": { + "common": "yes", + "direction": 1 + }, + "x_data": [ + 0.0, + 0.1111111111111111, + 0.2222222222222222, + 0.3333333333333333, + 0.4444444444444444, + 0.5555555555555557, + 0.6666666666666666, + 0.7777777777777778, + 0.8888888888888888, + 1.0 + ], + "x_data_unit": "1 / angstrom", + "x_tick_labels": [ + [ + 0, + "$\\Gamma$" + ], + [ + 9, + "" + ] + ], + "y_data": [ + 0.0, + 5.2631578947368425, + 10.526315789473685, + 15.789473684210527, + 21.05263157894737, + 26.315789473684212, + 31.578947368421055, + 36.8421052631579, + 42.10526315789474, + 47.36842105263158, + 52.631578947368425, + 57.89473684210527, + 63.15789473684211, + 
68.42105263157896, + 73.6842105263158, + 78.94736842105263, + 84.21052631578948, + 89.47368421052633, + 94.73684210526316, + 100.0 + ], + "y_data_unit": "millielectron_volt", + "z_data": [ + [ + 0.0, + 0.0, + 0.0, + 0.042222222222222223, + 0.0, + 0.021111111111111094, + 0.04222222222222224, + 0.0, + 0.06333333333333335, + 0.021111111111111094, + 0.04222222222222224, + 0.04222222222222215, + 0.0, + 0.0, + 0.0, + 0.0, + 0.04222222222222215, + 0.0, + 0.04222222222222224 + ], + [ + 0.06333333333333331, + 0.0, + 0.0, + 0.042222222222222223, + 0.0, + 0.021111111111111094, + 0.04222222222222224, + 0.021111111111111094, + 0.04222222222222224, + 0.021111111111111094, + 0.04222222222222224, + 0.04222222222222215, + 0.0, + 0.0, + 0.0, + 0.0, + 0.04222222222222215, + 0.0, + 0.04222222222222224 + ], + [ + 0.0211111111111111, + 0.0422222222222222, + 0.0, + 0.042222222222222223, + 0.0, + 0.021111111111111094, + 0.04222222222222224, + 0.04222222222222219, + 0.0, + 0.04222222222222219, + 0.04222222222222224, + 0.04222222222222215, + 0.0, + 0.0, + 0.0, + 0.0, + 0.04222222222222215, + 0.0, + 0.04222222222222224 + ], + [ + 0.0, + 0.06333333333333331, + 0.0, + 0.042222222222222223, + 0.0, + 0.021111111111111094, + 0.04222222222222224, + 0.04222222222222219, + 0.0, + 0.04222222222222219, + 0.04222222222222224, + 0.04222222222222215, + 0.0, + 0.0, + 0.0, + 0.021111111111111074, + 0.021111111111111074, + 0.0, + 0.02111111111111112 + ], + [ + 0.0, + 0.0211111111111111, + 0.0422222222222222, + 0.021111111111111112, + 0.021111111111111112, + 0.021111111111111094, + 0.04222222222222224, + 0.04222222222222219, + 0.0, + 0.06333333333333328, + 0.02111111111111112, + 0.04222222222222215, + 0.0, + 0.0, + 0.0, + 0.021111111111111074, + 0.021111111111111074, + 0.0, + 0.02111111111111112 + ], + [ + 0.0, + 0.0211111111111111, + 0.0422222222222222, + 0.021111111111111112, + 0.021111111111111112, + 0.021111111111111094, + 0.04222222222222224, + 0.04222222222222219, + 0.0, + 0.04222222222222219, + 
0.04222222222222224, + 0.04222222222222215, + 0.0, + 0.0, + 0.0, + 0.021111111111111074, + 0.021111111111111074, + 0.0, + 0.02111111111111112 + ], + [ + 0.0, + 0.0211111111111111, + 0.0211111111111111, + 0.042222222222222223, + 0.021111111111111112, + 0.021111111111111094, + 0.02111111111111112, + 0.06333333333333328, + 0.02111111111111112, + 0.04222222222222219, + 0.02111111111111112, + 0.04222222222222215, + 0.0, + 0.0, + 0.0, + 0.021111111111111074, + 0.021111111111111074, + 0.0, + 0.02111111111111112 + ], + [ + 0.0, + 0.0211111111111111, + 0.0211111111111111, + 0.06333333333333332, + 0.0, + 0.021111111111111094, + 0.02111111111111112, + 0.06333333333333328, + 0.02111111111111112, + 0.04222222222222219, + 0.02111111111111112, + 0.04222222222222215, + 0.0, + 0.0, + 0.0, + 0.021111111111111074, + 0.021111111111111074, + 0.0, + 0.02111111111111112 + ], + [ + 0.0, + 0.0211111111111111, + 0.0211111111111111, + 0.042222222222222223, + 0.021111111111111112, + 0.021111111111111094, + 0.04222222222222224, + 0.04222222222222219, + 0.02111111111111112, + 0.021111111111111094, + 0.04222222222222224, + 0.04222222222222215, + 0.0, + 0.0, + 0.0, + 0.021111111111111074, + 0.021111111111111074, + 0.0, + 0.02111111111111112 + ], + [ + 0.0, + 0.0211111111111111, + 0.0422222222222222, + 0.021111111111111112, + 0.021111111111111112, + 0.021111111111111094, + 0.04222222222222224, + 0.04222222222222219, + 0.0, + 0.06333333333333328, + 0.02111111111111112, + 0.04222222222222215, + 0.0, + 0.0, + 0.0, + 0.021111111111111074, + 0.021111111111111074, + 0.0, + 0.02111111111111112 + ] + ], + "z_data_unit": "1 / millielectron_volt" +} \ No newline at end of file diff --git a/tests_and_analysis/test/data/spectrum2d/quartz_fuzzy_map_2.json b/tests_and_analysis/test/data/spectrum2d/quartz_fuzzy_map_2.json new file mode 100644 index 000000000..12106d100 --- /dev/null +++ b/tests_and_analysis/test/data/spectrum2d/quartz_fuzzy_map_2.json @@ -0,0 +1,267 @@ +{ + "__euphonic_class__": "Spectrum2D", + 
"__euphonic_version__": "1.3.2+33.gd8680c2.dirty", + "metadata": { + "common": "yes", + "direction": 2 + }, + "x_data": [ + 0.0, + 0.1111111111111111, + 0.2222222222222222, + 0.3333333333333333, + 0.4444444444444444, + 0.5555555555555557, + 0.6666666666666666, + 0.7777777777777778, + 0.8888888888888888, + 1.0 + ], + "x_data_unit": "1 / angstrom", + "x_tick_labels": [ + [ + 0, + "$\\Gamma$" + ], + [ + 9, + "" + ] + ], + "y_data": [ + 0.0, + 5.2631578947368425, + 10.526315789473685, + 15.789473684210527, + 21.05263157894737, + 26.315789473684212, + 31.578947368421055, + 36.8421052631579, + 42.10526315789474, + 47.36842105263158, + 52.631578947368425, + 57.89473684210527, + 63.15789473684211, + 68.42105263157896, + 73.6842105263158, + 78.94736842105263, + 84.21052631578948, + 89.47368421052633, + 94.73684210526316, + 100.0 + ], + "y_data_unit": "millielectron_volt", + "z_data": [ + [ + 0.0, + 0.0, + 0.0, + 0.042222222222222223, + 0.0, + 0.021111111111111094, + 0.04222222222222224, + 0.0, + 0.08444444444444447, + 0.0, + 0.06333333333333335, + 0.0, + 0.02111111111111112, + 0.0, + 0.0, + 0.0, + 0.04222222222222215, + 0.0, + 0.06333333333333335 + ], + [ + 0.06333333333333331, + 0.0, + 0.0, + 0.042222222222222223, + 0.0, + 0.04222222222222219, + 0.02111111111111112, + 0.021111111111111094, + 0.04222222222222224, + 0.021111111111111094, + 0.06333333333333335, + 0.0, + 0.02111111111111112, + 0.0, + 0.0, + 0.021111111111111074, + 0.021111111111111074, + 0.0, + 0.06333333333333335 + ], + [ + 0.0211111111111111, + 0.0422222222222222, + 0.0, + 0.042222222222222223, + 0.0, + 0.04222222222222219, + 0.02111111111111112, + 0.021111111111111094, + 0.02111111111111112, + 0.04222222222222219, + 0.06333333333333335, + 0.0, + 0.02111111111111112, + 0.0, + 0.0, + 0.021111111111111074, + 0.0, + 0.021111111111111164, + 0.06333333333333335 + ], + [ + 0.0, + 0.0211111111111111, + 0.0422222222222222, + 0.042222222222222223, + 0.0, + 0.04222222222222219, + 0.04222222222222224, + 0.0, + 
0.02111111111111112, + 0.06333333333333328, + 0.04222222222222224, + 0.0, + 0.02111111111111112, + 0.0, + 0.02111111111111112, + 0.0, + 0.0, + 0.021111111111111164, + 0.06333333333333335 + ], + [ + 0.0, + 0.0211111111111111, + 0.0422222222222222, + 0.042222222222222223, + 0.0, + 0.04222222222222219, + 0.04222222222222224, + 0.0, + 0.02111111111111112, + 0.06333333333333328, + 0.04222222222222224, + 0.0, + 0.0, + 0.04222222222222224, + 0.0, + 0.0, + 0.0, + 0.0, + 0.08444444444444447 + ], + [ + 0.0, + 0.0211111111111111, + 0.0422222222222222, + 0.021111111111111112, + 0.06333333333333332, + 0.0, + 0.04222222222222224, + 0.0, + 0.0, + 0.08444444444444438, + 0.04222222222222224, + 0.0, + 0.0, + 0.04222222222222224, + 0.0, + 0.0, + 0.0, + 0.0, + 0.08444444444444447 + ], + [ + 0.0, + 0.0211111111111111, + 0.0422222222222222, + 0.042222222222222223, + 0.042222222222222223, + 0.0, + 0.04222222222222224, + 0.0, + 0.0, + 0.08444444444444438, + 0.04222222222222224, + 0.0, + 0.0, + 0.04222222222222224, + 0.0, + 0.0, + 0.0, + 0.0, + 0.08444444444444447 + ], + [ + 0.0, + 0.0211111111111111, + 0.0422222222222222, + 0.042222222222222223, + 0.0, + 0.04222222222222219, + 0.04222222222222224, + 0.0, + 0.02111111111111112, + 0.06333333333333328, + 0.04222222222222224, + 0.0, + 0.02111111111111112, + 0.0, + 0.02111111111111112, + 0.0, + 0.0, + 0.0, + 0.08444444444444447 + ], + [ + 0.0211111111111111, + 0.0422222222222222, + 0.0, + 0.042222222222222223, + 0.0, + 0.04222222222222219, + 0.04222222222222224, + 0.0, + 0.02111111111111112, + 0.06333333333333328, + 0.04222222222222224, + 0.0, + 0.02111111111111112, + 0.0, + 0.02111111111111112, + 0.0, + 0.0, + 0.021111111111111164, + 0.06333333333333335 + ], + [ + 0.0211111111111111, + 0.0422222222222222, + 0.0, + 0.042222222222222223, + 0.0, + 0.04222222222222219, + 0.02111111111111112, + 0.021111111111111094, + 0.02111111111111112, + 0.04222222222222219, + 0.06333333333333335, + 0.0, + 0.02111111111111112, + 0.0, + 0.0, + 
0.021111111111111074, + 0.0, + 0.021111111111111164, + 0.06333333333333335 + ] + ], + "z_data_unit": "1 / millielectron_volt" +} \ No newline at end of file diff --git a/tests_and_analysis/test/data/spectrum2dcollection/quartz_fuzzy_map.json b/tests_and_analysis/test/data/spectrum2dcollection/quartz_fuzzy_map.json new file mode 100644 index 000000000..2240424be --- /dev/null +++ b/tests_and_analysis/test/data/spectrum2dcollection/quartz_fuzzy_map.json @@ -0,0 +1,703 @@ +{ + "__euphonic_class__": "Spectrum2DCollection", + "__euphonic_version__": "1.3.2+33.gd8680c2.dirty", + "metadata": { + "common": "yes", + "line_data": [ + { + "direction": 0 + }, + { + "direction": 1 + }, + { + "direction": 2 + } + ] + }, + "x_data": [ + 0.0, + 0.1111111111111111, + 0.2222222222222222, + 0.3333333333333333, + 0.4444444444444444, + 0.5555555555555556, + 0.6666666666666666, + 0.7777777777777777, + 0.8888888888888888, + 1.0 + ], + "x_data_unit": "1 / angstrom", + "x_tick_labels": [ + [ + 0, + "$\\Gamma$" + ], + [ + 9, + "" + ] + ], + "y_data": [ + 0.0, + 5.2631578947368425, + 10.526315789473685, + 15.789473684210527, + 21.05263157894737, + 26.315789473684212, + 31.578947368421055, + 36.8421052631579, + 42.10526315789474, + 47.36842105263158, + 52.631578947368425, + 57.89473684210527, + 63.15789473684211, + 68.42105263157896, + 73.6842105263158, + 78.94736842105263, + 84.21052631578948, + 89.47368421052633, + 94.73684210526316, + 100.0 + ], + "y_data_unit": "millielectron_volt", + "z_data": [ + [ + [ + 0.0, + 0.0, + 0.0, + 0.042222222222222223, + 0.0, + 0.021111111111111094, + 0.04222222222222224, + 0.0, + 0.06333333333333335, + 0.021111111111111094, + 0.04222222222222224, + 0.04222222222222215, + 0.0, + 0.0, + 0.0, + 0.0, + 0.04222222222222215, + 0.0, + 0.04222222222222224 + ], + [ + 0.06333333333333331, + 0.0, + 0.0, + 0.042222222222222223, + 0.0, + 0.021111111111111094, + 0.04222222222222224, + 0.021111111111111094, + 0.04222222222222224, + 0.021111111111111094, + 
0.04222222222222224, + 0.04222222222222215, + 0.0, + 0.0, + 0.0, + 0.0, + 0.04222222222222215, + 0.0, + 0.04222222222222224 + ], + [ + 0.0211111111111111, + 0.0422222222222222, + 0.0, + 0.042222222222222223, + 0.0, + 0.021111111111111094, + 0.04222222222222224, + 0.04222222222222219, + 0.0, + 0.04222222222222219, + 0.04222222222222224, + 0.04222222222222215, + 0.0, + 0.0, + 0.0, + 0.021111111111111074, + 0.021111111111111074, + 0.0, + 0.04222222222222224 + ], + [ + 0.0, + 0.06333333333333331, + 0.0, + 0.042222222222222223, + 0.0, + 0.04222222222222219, + 0.02111111111111112, + 0.04222222222222219, + 0.0, + 0.04222222222222219, + 0.04222222222222224, + 0.04222222222222215, + 0.0, + 0.0, + 0.0, + 0.021111111111111074, + 0.021111111111111074, + 0.0, + 0.02111111111111112 + ], + [ + 0.0, + 0.0211111111111111, + 0.0422222222222222, + 0.021111111111111112, + 0.021111111111111112, + 0.04222222222222219, + 0.0, + 0.06333333333333328, + 0.0, + 0.06333333333333328, + 0.02111111111111112, + 0.021111111111111074, + 0.02111111111111112, + 0.0, + 0.0, + 0.021111111111111074, + 0.0, + 0.021111111111111164, + 0.02111111111111112 + ], + [ + 0.0, + 0.0211111111111111, + 0.0211111111111111, + 0.042222222222222223, + 0.021111111111111112, + 0.04222222222222219, + 0.02111111111111112, + 0.04222222222222219, + 0.0, + 0.06333333333333328, + 0.02111111111111112, + 0.021111111111111074, + 0.02111111111111112, + 0.0, + 0.02111111111111112, + 0.0, + 0.0, + 0.021111111111111164, + 0.02111111111111112 + ], + [ + 0.0, + 0.0211111111111111, + 0.0211111111111111, + 0.042222222222222223, + 0.021111111111111112, + 0.04222222222222219, + 0.02111111111111112, + 0.021111111111111094, + 0.02111111111111112, + 0.06333333333333328, + 0.04222222222222224, + 0.0, + 0.02111111111111112, + 0.0, + 0.02111111111111112, + 0.0, + 0.0, + 0.021111111111111164, + 0.02111111111111112 + ], + [ + 0.0, + 0.0, + 0.0422222222222222, + 0.021111111111111112, + 0.042222222222222223, + 0.04222222222222219, + 
0.04222222222222224, + 0.0, + 0.04222222222222224, + 0.04222222222222219, + 0.04222222222222224, + 0.0, + 0.0, + 0.04222222222222224, + 0.0, + 0.0, + 0.0, + 0.021111111111111164, + 0.04222222222222224 + ], + [ + 0.0, + 0.0, + 0.0422222222222222, + 0.021111111111111112, + 0.042222222222222223, + 0.04222222222222219, + 0.04222222222222224, + 0.0, + 0.08444444444444447, + 0.0, + 0.04222222222222224, + 0.0, + 0.0, + 0.04222222222222224, + 0.0, + 0.0, + 0.0, + 0.021111111111111164, + 0.06333333333333335 + ], + [ + 0.0, + 0.0211111111111111, + 0.0211111111111111, + 0.021111111111111112, + 0.042222222222222223, + 0.04222222222222219, + 0.04222222222222224, + 0.0, + 0.04222222222222224, + 0.04222222222222219, + 0.04222222222222224, + 0.0, + 0.02111111111111112, + 0.0, + 0.02111111111111112, + 0.0, + 0.0, + 0.021111111111111164, + 0.04222222222222224 + ] + ], + [ + [ + 0.0, + 0.0, + 0.0, + 0.042222222222222223, + 0.0, + 0.021111111111111094, + 0.04222222222222224, + 0.0, + 0.06333333333333335, + 0.021111111111111094, + 0.04222222222222224, + 0.04222222222222215, + 0.0, + 0.0, + 0.0, + 0.0, + 0.04222222222222215, + 0.0, + 0.04222222222222224 + ], + [ + 0.06333333333333331, + 0.0, + 0.0, + 0.042222222222222223, + 0.0, + 0.021111111111111094, + 0.04222222222222224, + 0.021111111111111094, + 0.04222222222222224, + 0.021111111111111094, + 0.04222222222222224, + 0.04222222222222215, + 0.0, + 0.0, + 0.0, + 0.0, + 0.04222222222222215, + 0.0, + 0.04222222222222224 + ], + [ + 0.0211111111111111, + 0.0422222222222222, + 0.0, + 0.042222222222222223, + 0.0, + 0.021111111111111094, + 0.04222222222222224, + 0.04222222222222219, + 0.0, + 0.04222222222222219, + 0.04222222222222224, + 0.04222222222222215, + 0.0, + 0.0, + 0.0, + 0.0, + 0.04222222222222215, + 0.0, + 0.04222222222222224 + ], + [ + 0.0, + 0.06333333333333331, + 0.0, + 0.042222222222222223, + 0.0, + 0.021111111111111094, + 0.04222222222222224, + 0.04222222222222219, + 0.0, + 0.04222222222222219, + 0.04222222222222224, + 
0.04222222222222215, + 0.0, + 0.0, + 0.0, + 0.021111111111111074, + 0.021111111111111074, + 0.0, + 0.02111111111111112 + ], + [ + 0.0, + 0.0211111111111111, + 0.0422222222222222, + 0.021111111111111112, + 0.021111111111111112, + 0.021111111111111094, + 0.04222222222222224, + 0.04222222222222219, + 0.0, + 0.06333333333333328, + 0.02111111111111112, + 0.04222222222222215, + 0.0, + 0.0, + 0.0, + 0.021111111111111074, + 0.021111111111111074, + 0.0, + 0.02111111111111112 + ], + [ + 0.0, + 0.0211111111111111, + 0.0422222222222222, + 0.021111111111111112, + 0.021111111111111112, + 0.021111111111111094, + 0.04222222222222224, + 0.04222222222222219, + 0.0, + 0.04222222222222219, + 0.04222222222222224, + 0.04222222222222215, + 0.0, + 0.0, + 0.0, + 0.021111111111111074, + 0.021111111111111074, + 0.0, + 0.02111111111111112 + ], + [ + 0.0, + 0.0211111111111111, + 0.0211111111111111, + 0.042222222222222223, + 0.021111111111111112, + 0.021111111111111094, + 0.02111111111111112, + 0.06333333333333328, + 0.02111111111111112, + 0.04222222222222219, + 0.02111111111111112, + 0.04222222222222215, + 0.0, + 0.0, + 0.0, + 0.021111111111111074, + 0.021111111111111074, + 0.0, + 0.02111111111111112 + ], + [ + 0.0, + 0.0211111111111111, + 0.0211111111111111, + 0.06333333333333332, + 0.0, + 0.021111111111111094, + 0.02111111111111112, + 0.06333333333333328, + 0.02111111111111112, + 0.04222222222222219, + 0.02111111111111112, + 0.04222222222222215, + 0.0, + 0.0, + 0.0, + 0.021111111111111074, + 0.021111111111111074, + 0.0, + 0.02111111111111112 + ], + [ + 0.0, + 0.0211111111111111, + 0.0211111111111111, + 0.042222222222222223, + 0.021111111111111112, + 0.021111111111111094, + 0.04222222222222224, + 0.04222222222222219, + 0.02111111111111112, + 0.021111111111111094, + 0.04222222222222224, + 0.04222222222222215, + 0.0, + 0.0, + 0.0, + 0.021111111111111074, + 0.021111111111111074, + 0.0, + 0.02111111111111112 + ], + [ + 0.0, + 0.0211111111111111, + 0.0422222222222222, + 0.021111111111111112, + 
0.021111111111111112, + 0.021111111111111094, + 0.04222222222222224, + 0.04222222222222219, + 0.0, + 0.06333333333333328, + 0.02111111111111112, + 0.04222222222222215, + 0.0, + 0.0, + 0.0, + 0.021111111111111074, + 0.021111111111111074, + 0.0, + 0.02111111111111112 + ] + ], + [ + [ + 0.0, + 0.0, + 0.0, + 0.042222222222222223, + 0.0, + 0.021111111111111094, + 0.04222222222222224, + 0.0, + 0.08444444444444447, + 0.0, + 0.06333333333333335, + 0.0, + 0.02111111111111112, + 0.0, + 0.0, + 0.0, + 0.04222222222222215, + 0.0, + 0.06333333333333335 + ], + [ + 0.06333333333333331, + 0.0, + 0.0, + 0.042222222222222223, + 0.0, + 0.04222222222222219, + 0.02111111111111112, + 0.021111111111111094, + 0.04222222222222224, + 0.021111111111111094, + 0.06333333333333335, + 0.0, + 0.02111111111111112, + 0.0, + 0.0, + 0.021111111111111074, + 0.021111111111111074, + 0.0, + 0.06333333333333335 + ], + [ + 0.0211111111111111, + 0.0422222222222222, + 0.0, + 0.042222222222222223, + 0.0, + 0.04222222222222219, + 0.02111111111111112, + 0.021111111111111094, + 0.02111111111111112, + 0.04222222222222219, + 0.06333333333333335, + 0.0, + 0.02111111111111112, + 0.0, + 0.0, + 0.021111111111111074, + 0.0, + 0.021111111111111164, + 0.06333333333333335 + ], + [ + 0.0, + 0.0211111111111111, + 0.0422222222222222, + 0.042222222222222223, + 0.0, + 0.04222222222222219, + 0.04222222222222224, + 0.0, + 0.02111111111111112, + 0.06333333333333328, + 0.04222222222222224, + 0.0, + 0.02111111111111112, + 0.0, + 0.02111111111111112, + 0.0, + 0.0, + 0.021111111111111164, + 0.06333333333333335 + ], + [ + 0.0, + 0.0211111111111111, + 0.0422222222222222, + 0.042222222222222223, + 0.0, + 0.04222222222222219, + 0.04222222222222224, + 0.0, + 0.02111111111111112, + 0.06333333333333328, + 0.04222222222222224, + 0.0, + 0.0, + 0.04222222222222224, + 0.0, + 0.0, + 0.0, + 0.0, + 0.08444444444444447 + ], + [ + 0.0, + 0.0211111111111111, + 0.0422222222222222, + 0.021111111111111112, + 0.06333333333333332, + 0.0, + 
0.04222222222222224, + 0.0, + 0.0, + 0.08444444444444438, + 0.04222222222222224, + 0.0, + 0.0, + 0.04222222222222224, + 0.0, + 0.0, + 0.0, + 0.0, + 0.08444444444444447 + ], + [ + 0.0, + 0.0211111111111111, + 0.0422222222222222, + 0.042222222222222223, + 0.042222222222222223, + 0.0, + 0.04222222222222224, + 0.0, + 0.0, + 0.08444444444444438, + 0.04222222222222224, + 0.0, + 0.0, + 0.04222222222222224, + 0.0, + 0.0, + 0.0, + 0.0, + 0.08444444444444447 + ], + [ + 0.0, + 0.0211111111111111, + 0.0422222222222222, + 0.042222222222222223, + 0.0, + 0.04222222222222219, + 0.04222222222222224, + 0.0, + 0.02111111111111112, + 0.06333333333333328, + 0.04222222222222224, + 0.0, + 0.02111111111111112, + 0.0, + 0.02111111111111112, + 0.0, + 0.0, + 0.0, + 0.08444444444444447 + ], + [ + 0.0211111111111111, + 0.0422222222222222, + 0.0, + 0.042222222222222223, + 0.0, + 0.04222222222222219, + 0.04222222222222224, + 0.0, + 0.02111111111111112, + 0.06333333333333328, + 0.04222222222222224, + 0.0, + 0.02111111111111112, + 0.0, + 0.02111111111111112, + 0.0, + 0.0, + 0.021111111111111164, + 0.06333333333333335 + ], + [ + 0.0211111111111111, + 0.0422222222222222, + 0.0, + 0.042222222222222223, + 0.0, + 0.04222222222222219, + 0.02111111111111112, + 0.021111111111111094, + 0.02111111111111112, + 0.04222222222222219, + 0.06333333333333335, + 0.0, + 0.02111111111111112, + 0.0, + 0.0, + 0.021111111111111074, + 0.0, + 0.021111111111111164, + 0.06333333333333335 + ] + ] + ], + "z_data_unit": "1 / millielectron_volt" +} \ No newline at end of file diff --git a/tests_and_analysis/test/euphonic_test/test_spectrum2dcollection.py b/tests_and_analysis/test/euphonic_test/test_spectrum2dcollection.py index f5877f8a5..acc5c46e4 100644 --- a/tests_and_analysis/test/euphonic_test/test_spectrum2dcollection.py +++ b/tests_and_analysis/test/euphonic_test/test_spectrum2dcollection.py @@ -6,6 +6,18 @@ from euphonic import Quantity, ureg from euphonic.spectra import OneLineData, Spectrum2D, Spectrum2DCollection 
+from tests_and_analysis.test.utils import get_data_path +from .test_spectrum2d import check_spectrum2d, get_spectrum2d + + +def get_spectrum2dcollection_path(*subpaths): + return get_data_path('spectrum2dcollection', *subpaths) + + +def get_spectrum2dcollection(json_filename): + return Spectrum2DCollection.from_json_file( + get_spectrum2dcollection_path(json_filename)) + def rand_spectrum2d(seed: int = 1, x_bins: Optional[Quantity] = None, @@ -71,3 +83,46 @@ def test_init_from_spectra(self): spectrum = Spectrum2DCollection.from_spectra( [spec_2d, spec_2d_inconsistent]) assert spectrum + + def test_from_spectra(self): + spectra = [get_spectrum2d(f"quartz_fuzzy_map_{i}.json") + for i in range(3)] + collection = Spectrum2DCollection.from_spectra(spectra) + + ref_collection = get_spectrum2dcollection("quartz_fuzzy_map.json") + + for attr in ("x_data", "y_data", "z_data"): + new, ref = getattr(collection, attr), getattr(ref_collection, attr) + assert new.units == ref.units + np.testing.assert_allclose(new, ref) + + if ref_collection.metadata is None: + assert collection.metadata is None + else: + assert ref_collection.metadata == collection.metadata + + def test_indexing(self): + """Check indexing an element, slice and iteration + + - Individual index should yield corresponding Spectrum2D + - A slice should yield a new Spectrum2DCollection + - Iteration should yield a series of Spectrum2D + + """ + # TODO move spectrum load to a common fixture + + spectra = [get_spectrum2d(f"quartz_fuzzy_map_{i}.json") + for i in range(3)] + collection = get_spectrum2dcollection("quartz_fuzzy_map.json") + + item_1 = collection[1] + assert isinstance(item_1, Spectrum2D) + check_spectrum2d(item_1, spectra[1]) + + item_1_to_end = collection[1:] + assert isinstance(item_1_to_end, Spectrum2DCollection) + assert item_1_to_end != collection + + for item, ref in zip(item_1_to_end, spectra[1:]): + assert isinstance(item, Spectrum2D) + check_spectrum2d(item, ref) From 
6235aabb7121dfa7c2ba2271cd7364761eb211b4 Mon Sep 17 00:00:00 2001 From: "Adam J. Jackson" Date: Tue, 17 Sep 2024 16:31:58 +0100 Subject: [PATCH 23/37] Spectrum2DCollection; more testing, fix sum() - Refactor re-used data import to use pytest fixture - Remove initial from_spectra test; new one covers it all - Test from_spectra with inconsistent input - Test mixin-supplied methods --- euphonic/spectra.py | 2 +- .../test_spectrum2dcollection.py | 133 +++++++++++++----- 2 files changed, 100 insertions(+), 35 deletions(-) diff --git a/euphonic/spectra.py b/euphonic/spectra.py index 66219fd0b..869d3917e 100644 --- a/euphonic/spectra.py +++ b/euphonic/spectra.py @@ -776,7 +776,7 @@ def sum(self) -> Spectrum: np.sum(self._get_raw_spectrum_data(), axis=0), units=self._get_internal_spectrum_data_unit() ).to(self._get_spectrum_data_unit()) - return Spectrum1D( + return self._item_type( **self._get_bin_kwargs(), **{self._spectrum_data_name(): summed_s_data}, x_tick_labels=copy.copy(self.x_tick_labels), diff --git a/tests_and_analysis/test/euphonic_test/test_spectrum2dcollection.py b/tests_and_analysis/test/euphonic_test/test_spectrum2dcollection.py index acc5c46e4..99218d497 100644 --- a/tests_and_analysis/test/euphonic_test/test_spectrum2dcollection.py +++ b/tests_and_analysis/test/euphonic_test/test_spectrum2dcollection.py @@ -19,6 +19,40 @@ def get_spectrum2dcollection(json_filename): get_spectrum2dcollection_path(json_filename)) +@pytest.fixture +def quartz_fuzzy_collection(): + return get_spectrum2dcollection("quartz_fuzzy_map.json") + + +@pytest.fixture +def quartz_fuzzy_items(): + return [get_spectrum2d(f"quartz_fuzzy_map_{i}.json") for i in range(3)] + +@pytest.fixture +def inconsistent_x_item(): + item = get_spectrum2d("quartz_fuzzy_map_0.json") + item._x_data *= 2. 
+ return item + +@pytest.fixture +def inconsistent_x_units_item(): + item = get_spectrum2d("quartz_fuzzy_map_0.json") + item.x_data_unit = "1/bohr" + return item + +@pytest.fixture +def inconsistent_x_length_item(): + item = get_spectrum2d("quartz_fuzzy_map_0.json") + item._x_data = item._x_data[:-2] + item._z_data = item._z_data[:-2, :] + return item + +@pytest.fixture +def inconsistent_y_item(): + item = get_spectrum2d("quartz_fuzzy_map_0.json") + item._y_data *= 2. + return item + def rand_spectrum2d(seed: int = 1, x_bins: Optional[Quantity] = None, y_bins: Optional[Quantity] = None, @@ -66,30 +100,17 @@ def test_init_from_numbers(self): x_data, y_data, z_data, x_tick_labels=x_tick_labels, metadata=metadata) - assert spectrum - - def test_init_from_spectra(self): - """Construct collection from a series of Spectrum2D""" - spec_2d = rand_spectrum2d(seed=1) - spec_2d_consistent = rand_spectrum2d().copy() - spec_2d_consistent._z_data *= 2 - spec_2d.metadata["index"] = 2 - - spectrum = Spectrum2DCollection.from_spectra( - [spec_2d, spec_2d_consistent]) - - spec_2d_inconsistent = rand_spectrum2d(seed=2) - with pytest.raises(ValueError): - spectrum = Spectrum2DCollection.from_spectra( - [spec_2d, spec_2d_inconsistent]) - assert spectrum + for attr, data in [("x_data", x_data), + ("y_data", y_data), + ("z_data", z_data)]: + np.testing.assert_allclose(getattr(spectrum, attr), data) - def test_from_spectra(self): - spectra = [get_spectrum2d(f"quartz_fuzzy_map_{i}.json") - for i in range(3)] - collection = Spectrum2DCollection.from_spectra(spectra) + assert spectrum.metadata == metadata - ref_collection = get_spectrum2dcollection("quartz_fuzzy_map.json") + def test_from_spectra(self, quartz_fuzzy_collection, quartz_fuzzy_items): + """Use alternate constructor Spectrum2DCollection.from_spectra()""" + collection = Spectrum2DCollection.from_spectra(quartz_fuzzy_items) + ref_collection = quartz_fuzzy_collection for attr in ("x_data", "y_data", "z_data"): new, ref = 
getattr(collection, attr), getattr(ref_collection, attr) @@ -101,7 +122,36 @@ def test_from_spectra(self): else: assert ref_collection.metadata == collection.metadata - def test_indexing(self): + def test_from_bad_spectra( + self, + quartz_fuzzy_items, + inconsistent_x_item, + inconsistent_x_length_item, + inconsistent_x_units_item, + inconsistent_y_item): + """Spectrum2DCollection.from_spectra with inconsistent input""" + + with pytest.raises(AssertionError): + Spectrum2DCollection.from_spectra( + quartz_fuzzy_items + [inconsistent_x_item] + ) + + with pytest.raises(AssertionError): + Spectrum2DCollection.from_spectra( + quartz_fuzzy_items + [inconsistent_x_units_item] + ) + + with pytest.raises(ValueError): + Spectrum2DCollection.from_spectra( + quartz_fuzzy_items + [inconsistent_x_length_item] + ) + + with pytest.raises(AssertionError): + Spectrum2DCollection.from_spectra( + quartz_fuzzy_items + [inconsistent_y_item] + ) + + def test_indexing(self, quartz_fuzzy_collection, quartz_fuzzy_items): """Check indexing an element, slice and iteration - Individual index should yield corresponding Spectrum2D @@ -109,20 +159,35 @@ def test_indexing(self): - Iteration should yield a series of Spectrum2D """ - # TODO move spectrum load to a common fixture - - spectra = [get_spectrum2d(f"quartz_fuzzy_map_{i}.json") - for i in range(3)] - collection = get_spectrum2dcollection("quartz_fuzzy_map.json") - - item_1 = collection[1] + item_1 = quartz_fuzzy_collection[1] assert isinstance(item_1, Spectrum2D) - check_spectrum2d(item_1, spectra[1]) + check_spectrum2d(item_1, quartz_fuzzy_items[1]) - item_1_to_end = collection[1:] + item_1_to_end = quartz_fuzzy_collection[1:] assert isinstance(item_1_to_end, Spectrum2DCollection) - assert item_1_to_end != collection + assert item_1_to_end != quartz_fuzzy_collection - for item, ref in zip(item_1_to_end, spectra[1:]): + for item, ref in zip(item_1_to_end, quartz_fuzzy_items[1:]): assert isinstance(item, Spectrum2D) check_spectrum2d(item, 
ref) + + def test_collection_methods(self, quartz_fuzzy_collection): + """Check methods from SpectrumCollectionMixin + + These are checked thoroughly for Spectrum1DCollection, but here we + try to ensure the generic implementation works correctly in 2-D + + """ + + total = quartz_fuzzy_collection.sum() + assert isinstance(total, Spectrum2D) + assert total.z_data[3, 3] == sum(spec.z_data[3, 3] + for spec in quartz_fuzzy_collection) + + extended = quartz_fuzzy_collection + quartz_fuzzy_collection + assert len(extended) == 2 * len(quartz_fuzzy_collection) + np.testing.assert_allclose(extended.sum().z_data, total.z_data * 2) + + selection = quartz_fuzzy_collection.select(direction=2, common="yes") + ref_item_2 = get_spectrum2d("quartz_fuzzy_map_2.json") + check_spectrum2d(selection.sum(), ref_item_2) From e5c97561c21a4a7ff0d6d07b2a4be43bd728d510 Mon Sep 17 00:00:00 2001 From: "Adam J. Jackson" Date: Tue, 17 Sep 2024 16:44:15 +0100 Subject: [PATCH 24/37] Fix numpy/pint warnings --- .../euphonic_test/test_spectrum2dcollection.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/tests_and_analysis/test/euphonic_test/test_spectrum2dcollection.py b/tests_and_analysis/test/euphonic_test/test_spectrum2dcollection.py index 99218d497..e2bb4ddd4 100644 --- a/tests_and_analysis/test/euphonic_test/test_spectrum2dcollection.py +++ b/tests_and_analysis/test/euphonic_test/test_spectrum2dcollection.py @@ -43,14 +43,14 @@ def inconsistent_x_units_item(): @pytest.fixture def inconsistent_x_length_item(): item = get_spectrum2d("quartz_fuzzy_map_0.json") - item._x_data = item._x_data[:-2] - item._z_data = item._z_data[:-2, :] + item.x_data = item.x_data[:-2] + item.z_data = item.z_data[:-2, :] return item @pytest.fixture def inconsistent_y_item(): item = get_spectrum2d("quartz_fuzzy_map_0.json") - item._y_data *= 2. + item.y_data = item.y_data * 2. 
return item def rand_spectrum2d(seed: int = 1, @@ -103,7 +103,9 @@ def test_init_from_numbers(self): for attr, data in [("x_data", x_data), ("y_data", y_data), ("z_data", z_data)]: - np.testing.assert_allclose(getattr(spectrum, attr), data) + assert getattr(spectrum, attr).units == data.units + np.testing.assert_allclose(getattr(spectrum, attr).magnitude, + data.magnitude) assert spectrum.metadata == metadata @@ -115,7 +117,7 @@ def test_from_spectra(self, quartz_fuzzy_collection, quartz_fuzzy_items): for attr in ("x_data", "y_data", "z_data"): new, ref = getattr(collection, attr), getattr(ref_collection, attr) assert new.units == ref.units - np.testing.assert_allclose(new, ref) + np.testing.assert_allclose(new.magnitude, ref.magnitude) if ref_collection.metadata is None: assert collection.metadata is None @@ -186,7 +188,8 @@ def test_collection_methods(self, quartz_fuzzy_collection): extended = quartz_fuzzy_collection + quartz_fuzzy_collection assert len(extended) == 2 * len(quartz_fuzzy_collection) - np.testing.assert_allclose(extended.sum().z_data, total.z_data * 2) + np.testing.assert_allclose(extended.sum().z_data.magnitude, + total.z_data.magnitude * 2) selection = quartz_fuzzy_collection.select(direction=2, common="yes") ref_item_2 = get_spectrum2d("quartz_fuzzy_map_2.json") From 56dfcb923950f7cb3bf8a1793d7a2320b76386b8 Mon Sep 17 00:00:00 2001 From: "Adam J. 
Jackson" Date: Tue, 17 Sep 2024 16:58:16 +0100 Subject: [PATCH 25/37] Linting --- .../test_spectrum2dcollection.py | 40 ++++++++++++++----- 1 file changed, 29 insertions(+), 11 deletions(-) diff --git a/tests_and_analysis/test/euphonic_test/test_spectrum2dcollection.py b/tests_and_analysis/test/euphonic_test/test_spectrum2dcollection.py index e2bb4ddd4..96029bc8e 100644 --- a/tests_and_analysis/test/euphonic_test/test_spectrum2dcollection.py +++ b/tests_and_analysis/test/euphonic_test/test_spectrum2dcollection.py @@ -1,3 +1,8 @@ +"""Unit tests for Spectrum2DCollection""" + +# Stop the linter from complaining when pytest fixtures are used idiomatically +# pylint: disable=redefined-outer-name + from typing import Optional import numpy as np @@ -11,37 +16,44 @@ def get_spectrum2dcollection_path(*subpaths): + """Get Spectrum2DCollection reference data path""" return get_data_path('spectrum2dcollection', *subpaths) def get_spectrum2dcollection(json_filename): + """Get Spectrum2DCollection reference data object""" return Spectrum2DCollection.from_json_file( get_spectrum2dcollection_path(json_filename)) @pytest.fixture -def quartz_fuzzy_collection(): +def quartz_fuzzy_collection() -> Spectrum2DCollection: + """Coarsely sampled quartz bands in a few directions""" return get_spectrum2dcollection("quartz_fuzzy_map.json") @pytest.fixture -def quartz_fuzzy_items(): +def quartz_fuzzy_items() -> list[Spectrum2D]: + """Individual spectra corresponding to quartz_fuzzy_collection""" return [get_spectrum2d(f"quartz_fuzzy_map_{i}.json") for i in range(3)] @pytest.fixture -def inconsistent_x_item(): +def inconsistent_x_item() -> Spectrum2D: + """Spectrum with different x values""" item = get_spectrum2d("quartz_fuzzy_map_0.json") item._x_data *= 2. 
return item @pytest.fixture def inconsistent_x_units_item(): + """Spectrum with different x units""" item = get_spectrum2d("quartz_fuzzy_map_0.json") item.x_data_unit = "1/bohr" return item @pytest.fixture def inconsistent_x_length_item(): + """Spectrum with different number of x values""" item = get_spectrum2d("quartz_fuzzy_map_0.json") item.x_data = item.x_data[:-2] item.z_data = item.z_data[:-2, :] @@ -49,6 +61,7 @@ def inconsistent_x_length_item(): @pytest.fixture def inconsistent_y_item(): + """Spectrum with different y values""" item = get_spectrum2d("quartz_fuzzy_map_0.json") item.y_data = item.y_data * 2. return item @@ -81,20 +94,21 @@ def rand_spectrum2d(seed: int = 1, class TestSpectrum2DCollectionCreation: + """Unit tests for Spectrum2DCollection constructors""" def test_init_from_numbers(self): """Construct Spectrum2DCollection with __init__()""" - N_X = 10 - N_Y = 20 - N_Z = 5 + n_x = 10 + n_y = 20 + n_z = 5 - x_data = ureg.Quantity(np.linspace(0, 100, N_X), "1 / angstrom") - y_data = ureg.Quantity(np.linspace(0, 2000, N_Y), "meV") - z_data = ureg.Quantity(np.random.random((N_Z, N_X, N_Y)), "1 / meV") + x_data = ureg.Quantity(np.linspace(0, 100, n_x), "1 / angstrom") + y_data = ureg.Quantity(np.linspace(0, 2000, n_y), "meV") + z_data = ureg.Quantity(np.random.random((n_z, n_x, n_y)), "1 / meV") metadata = {"flavour": "chocolate", - "line_data": [{"index": i} for i in range(N_Z)]} + "line_data": [{"index": i} for i in range(n_z)]} - x_tick_labels = [(0, "Start"), (N_X - 1, "END")] + x_tick_labels = [(0, "Start"), (n_x - 1, "END")] spectrum = Spectrum2DCollection( x_data, y_data, z_data, @@ -124,6 +138,7 @@ def test_from_spectra(self, quartz_fuzzy_collection, quartz_fuzzy_items): else: assert ref_collection.metadata == collection.metadata + # pylint: disable=R0913 # These fixtures are "too many arguments" def test_from_bad_spectra( self, quartz_fuzzy_items, @@ -153,6 +168,9 @@ def test_from_bad_spectra( quartz_fuzzy_items + [inconsistent_y_item] ) 
+class TestSpectrum2DCollectionFunctionality: + """Unit test indexing and methods of Spectrum2DCollection""" + def test_indexing(self, quartz_fuzzy_collection, quartz_fuzzy_items): """Check indexing an element, slice and iteration From ade79e1a398d4c144897188f7a00495dbded886d Mon Sep 17 00:00:00 2001 From: "Adam J. Jackson" Date: Thu, 19 Sep 2024 11:14:05 +0100 Subject: [PATCH 26/37] Apply suggestions from code review Make things a little cleaner and more idiomatic Co-authored-by: Jacob Wilkins <46597752+oerc0122@users.noreply.github.com> --- euphonic/spectra.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/euphonic/spectra.py b/euphonic/spectra.py index 869d3917e..d24a3b88b 100644 --- a/euphonic/spectra.py +++ b/euphonic/spectra.py @@ -831,7 +831,7 @@ def _validate_item(self, item: Integral | slice | Sequence[Integral] | np.ndarra raise IndexError(f'index "{item.stop}" out of range') return - if not all([isinstance(i, Integral) for i in item]): + if not all(isinstance(i, Integral) for i in item): raise TypeError( f'Index "{item}" should be an integer, slice ' f'or sequence of ints') @@ -879,9 +879,9 @@ def __add__(self, other: Self) -> Self: def iter_metadata(self) -> Generator[OneLineData, None, None]: """Iterate over metadata dicts of individual spectra from collection""" - common_metadata = dict( - (key, self.metadata[key]) - for key in set(self.metadata.keys()) - {"line_data",}) + common_metadata = { + key: self.metadata[key] + for key in set(self.metadata.keys()) - {"line_data",}} line_data = self.metadata.get("line_data") if line_data is None: @@ -1039,7 +1039,7 @@ def to_dict(self) -> Dict[str, Any]: @classmethod def from_dict(cls: Self, d: dict) -> Self: """Initialise a Spectrum Collection object from dict""" - data_keys = list(f"{dim}_data" for dim in cls._bin_axes) + data_keys = [f"{dim}_data" for dim in cls._bin_axes] data_keys.append(cls._spectrum_data_name()) d = _process_dict(d, @@ -1852,7 +1852,7 @@ def __init__( 
f'z_data contains {len(z_data)} spectra, but ' f'metadata["line_data"] contains ' f'{len(metadata["line_data"])} entries') - self.metadata = {} if metadata is None else metadata + self.metadata = metadata if metadata is not None else {} def _split_by_indices(self, indices: Sequence[int] | np.ndarray ) -> List[Self]: From 3ecad2c15a2f7815bacd88e4cfbf63fb993a2685 Mon Sep 17 00:00:00 2001 From: "Adam J. Jackson" Date: Thu, 19 Sep 2024 11:37:13 +0100 Subject: [PATCH 27/37] Response to review: docstring improvements --- euphonic/spectra.py | 31 ++++++++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/euphonic/spectra.py b/euphonic/spectra.py index d24a3b88b..92ed84e7a 100644 --- a/euphonic/spectra.py +++ b/euphonic/spectra.py @@ -823,7 +823,14 @@ def __getitem__( def _validate_item(self, item: Integral | slice | Sequence[Integral] | np.ndarray ) -> None: - """Raise Error if index has inappropriate typing/range""" + """Raise Error if index has inappropriate typing/ranges + + Raises: + IndexError: Slice is not compatible with size of collection + + TypeError: item specification does not have acceptable type; e.g. + a sequence of float or bool was provided when ints are needed. + """ if isinstance(item, Integral): return if isinstance(item, slice): @@ -891,6 +898,28 @@ def iter_metadata(self) -> Generator[OneLineData, None, None]: yield common_metadata | one_line_data def _select_indices(self, **select_key_values) -> list[int]: + """Get indices of items that match metadata query + + The target key-value pairs are a subset of the matching data, e.g. 
+ + self._select_indices(species="Na", weight="coherent") + + will match metadata rows + + {"species": "Na", "weight": "coherent"} + + and + + {"species": "Na", "weight": "coherent", "mass": "22.9898"} + + but not + + {"species": "Na"} + + or + + {"species": "K", "weight": "coherent"} + """ required_metadata = select_key_values.items() indices = [i for i, row in enumerate(self.iter_metadata()) if required_metadata <= row.items()] From 8631bc43c110747e0ccf01439a34a2ff63f8d004 Mon Sep 17 00:00:00 2001 From: "Adam J. Jackson" Date: Thu, 19 Sep 2024 15:56:44 +0100 Subject: [PATCH 28/37] Numpy style docstring --- euphonic/spectra.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/euphonic/spectra.py b/euphonic/spectra.py index 92ed84e7a..5ffa138a7 100644 --- a/euphonic/spectra.py +++ b/euphonic/spectra.py @@ -825,11 +825,15 @@ def _validate_item(self, item: Integral | slice | Sequence[Integral] | np.ndarra ) -> None: """Raise Error if index has inappropriate typing/ranges - Raises: - IndexError: Slice is not compatible with size of collection + Raises + ------ + IndexError + Slice is not compatible with size of collection + + TypeError + item specification does not have acceptable type; e.g. a sequence + of float or bool was provided when ints are needed. - TypeError: item specification does not have acceptable type; e.g. - a sequence of float or bool was provided when ints are needed. """ if isinstance(item, Integral): return From eb808ba2149943ca5279c30501adb7ddff35eeb8 Mon Sep 17 00:00:00 2001 From: "Adam J. 
Jackson" Date: Thu, 19 Sep 2024 16:40:37 +0100 Subject: [PATCH 29/37] More tweaks from review: improve clarity --- euphonic/spectra.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/euphonic/spectra.py b/euphonic/spectra.py index 5ffa138a7..97741f69d 100644 --- a/euphonic/spectra.py +++ b/euphonic/spectra.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod import collections import copy -from functools import partial +from functools import partial, reduce import itertools import math import json @@ -17,6 +17,7 @@ from pint import DimensionalityError, Quantity import numpy as np from scipy.ndimage import correlate1d, gaussian_filter +from toolz.dicttoolz import valmap from toolz.itertoolz import groupby, pluck from euphonic import ureg, __version__ @@ -955,10 +956,12 @@ def select(self, **select_key_values: Union[ If no matching spectra are found """ # Convert all items to sequences of possibilities - select_key_values = dict( - (key, (value,)) if isinstance(value, (int, str)) else (key, value) - for key, value in select_key_values.items() - ) + def ensure_sequence(value: int | str | Sequence[int | str] + ) -> Sequence[int | str]: + return (value,) if isinstance(value, (int, str)) else value + + select_key_values = valmap(ensure_sequence, select_key_values) + # Collect indices that match each combination of values selected_indices = [] @@ -989,8 +992,8 @@ def _combine_metadata(all_metadata: LineData) -> Metadata: # Combine all common key/value pairs into new dict combined_metadata = dict( - set(all_metadata[0].items()).intersection( - *[metadata.items() for metadata in all_metadata[1:]])) + reduce(set.intersection, + (set(metadata.items()) for metadata in all_metadata))) # Put all other per-spectrum metadata in line_data line_data = [ @@ -1867,8 +1870,8 @@ def __init__( [Quantity, [list, type(None)], [dict, type(None)]], [(-1, -1, -1), (), ()], ['z_data', 'x_tick_labels', 'metadata']) - nx = z_data.shape[1] - ny = 
z_data.shape[2] + # First axis corresponds to spectra in collection + _, nx, ny = z_data.shape _check_constructor_inputs( [x_data, y_data], [Quantity, Quantity], From 19f36ab298a4412cfd05c6af285eb83ba49ecd9d Mon Sep 17 00:00:00 2001 From: "Adam J. Jackson" Date: Thu, 19 Sep 2024 16:58:39 +0100 Subject: [PATCH 30/37] Factor out metadata length check to common method --- euphonic/spectra.py | 37 ++++++++++++++++++++++++------------- 1 file changed, 24 insertions(+), 13 deletions(-) diff --git a/euphonic/spectra.py b/euphonic/spectra.py index 97741f69d..a273f20a6 100644 --- a/euphonic/spectra.py +++ b/euphonic/spectra.py @@ -1016,6 +1016,25 @@ def _tidy_metadata(self) -> Metadata: combined_line_data.pop("line_data", None) return combined_line_data + def _check_metadata(self) -> None: + """Check self.metadata['line_data'] is consistent with collection size + + Raises + ------ + ValueError + Metadata contains 'line_data' of incorrect length + + """ + if 'line_data' in self.metadata: + collection_size = len(self._get_raw_spectrum_data()) + n_lines = len(self.metadata['line_data']) + + if n_lines != collection_size: + raise ValueError( + f'{self._spectrum_data_name()} contains {collection_size} ' + f'spectra, but metadata["line_data"] contains ' + f'{n_lines} entries') + def group_by(self, *line_data_keys: str) -> Self: """ Group and sum elements of spectral data according to the values @@ -1171,13 +1190,9 @@ def __init__( self._set_data(x_data, 'x') self._set_data(y_data, 'y') self.x_tick_labels = x_tick_labels - if metadata and 'line_data' in metadata.keys(): - if len(metadata['line_data']) != len(y_data): - raise ValueError( - f'y_data contains {len(y_data)} spectra, but ' - f'metadata["line_data"] contains ' - f'{len(metadata["line_data"])} entries') - self.metadata = {} if metadata is None else metadata + + self.metadata = metadata if metadata is not None else {} + self._check_metadata() def _split_by_indices(self, indices: Union[Sequence[int], np.ndarray] @@ 
-1882,13 +1897,9 @@ def __init__( self._set_data(y_data, 'y') self.x_tick_labels = x_tick_labels self._set_data(z_data, 'z') - if metadata and 'line_data' in metadata.keys(): - if len(metadata['line_data']) != len(z_data): - raise ValueError( - f'z_data contains {len(z_data)} spectra, but ' - f'metadata["line_data"] contains ' - f'{len(metadata["line_data"])} entries') + self.metadata = metadata if metadata is not None else {} + self._check_metadata() def _split_by_indices(self, indices: Sequence[int] | np.ndarray ) -> List[Self]: From 5dda596db654d287da6aa4eefa12cf3e1e9eb74f Mon Sep 17 00:00:00 2001 From: "Adam J. Jackson" Date: Thu, 19 Sep 2024 17:04:08 +0100 Subject: [PATCH 31/37] Remove unnecessary parens --- euphonic/spectra.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/euphonic/spectra.py b/euphonic/spectra.py index a273f20a6..3462db5b7 100644 --- a/euphonic/spectra.py +++ b/euphonic/spectra.py @@ -1952,7 +1952,7 @@ def _type_check(spectrum): _type_check(spectrum) spectrum_i_raw_data = cls._get_item_raw_data(spectrum) spectrum_i_data_units = cls._get_item_data_unit(spectrum) - assert (spectrum_i_data_units == spectrum_data_units) + assert spectrum_i_data_units == spectrum_data_units for key, ref_bins in bins_data.items(): item_bins = getattr(spectrum, key) From b647cb737153f039e744962e1cc720e02d087b54 Mon Sep 17 00:00:00 2001 From: "Adam J. 
Jackson" Date: Thu, 19 Sep 2024 17:18:43 +0100 Subject: [PATCH 32/37] Easier-to-read dict comprehension with less set magic Suggested by @oerc0122 --- euphonic/spectra.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/euphonic/spectra.py b/euphonic/spectra.py index 3462db5b7..8d09b014d 100644 --- a/euphonic/spectra.py +++ b/euphonic/spectra.py @@ -893,7 +893,7 @@ def iter_metadata(self) -> Generator[OneLineData, None, None]: """Iterate over metadata dicts of individual spectra from collection""" common_metadata = { key: self.metadata[key] - for key in set(self.metadata.keys()) - {"line_data",}} + for key in self.metadata if key != "line_data"} line_data = self.metadata.get("line_data") if line_data is None: From aaaa25a2f48b3d85fc0cfd7eb7a95f7a74e78732 Mon Sep 17 00:00:00 2001 From: "Adam J. Jackson" Date: Fri, 20 Sep 2024 09:49:34 +0100 Subject: [PATCH 33/37] Use keyfilter from toolz - This avoids some of the repetition in dict comprehensions to remove an element - Here we also slightly rework _combine_metadata so it is clearer what each variable represents. 
--- euphonic/spectra.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/euphonic/spectra.py b/euphonic/spectra.py index 8d09b014d..41cbbc1e4 100644 --- a/euphonic/spectra.py +++ b/euphonic/spectra.py @@ -17,7 +17,7 @@ from pint import DimensionalityError, Quantity import numpy as np from scipy.ndimage import correlate1d, gaussian_filter -from toolz.dicttoolz import valmap +from toolz.dicttoolz import keyfilter, valmap from toolz.itertoolz import groupby, pluck from euphonic import ureg, __version__ @@ -891,9 +891,8 @@ def __add__(self, other: Self) -> Self: def iter_metadata(self) -> Generator[OneLineData, None, None]: """Iterate over metadata dicts of individual spectra from collection""" - common_metadata = { - key: self.metadata[key] - for key in self.metadata if key != "line_data"} + common_metadata = keyfilter(lambda key: key != "line_data", + self.metadata) line_data = self.metadata.get("line_data") if line_data is None: @@ -991,20 +990,19 @@ def _combine_metadata(all_metadata: LineData) -> Metadata: assert 'line_data' not in metadata.keys() # Combine all common key/value pairs into new dict - combined_metadata = dict( + common_metadata = dict( reduce(set.intersection, (set(metadata.items()) for metadata in all_metadata))) # Put all other per-spectrum metadata in line_data - line_data = [ - {key: value for key, value in metadata.items() - if key not in combined_metadata} - for metadata in all_metadata - ] + line_data = [keyfilter(lambda key: key not in common_metadata, + one_line_data) + for one_line_data in all_metadata] + if any(line_data): - combined_metadata['line_data'] = line_data + return common_metadata | {'line_data': line_data} - return combined_metadata + return common_metadata def _tidy_metadata(self) -> Metadata: """ From 87f32a0b19963ea1485fde6e3c4025f38655d636 Mon Sep 17 00:00:00 2001 From: "Adam J. 
Jackson" Date: Fri, 20 Sep 2024 16:53:45 +0100 Subject: [PATCH 34/37] Simplify _get_item_metadata The list comprehension is a little clunky but avoids the 1-length special case: cleaner overall --- euphonic/spectra.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/euphonic/spectra.py b/euphonic/spectra.py index 41cbbc1e4..f1ed6e5dd 100644 --- a/euphonic/spectra.py +++ b/euphonic/spectra.py @@ -865,10 +865,8 @@ def _get_item_metadata(self, item): # noqa: F811 return metadata_lines[item] if isinstance(item, slice): return self._combine_metadata(metadata_lines[item]) - if len(item) == 1: - return metadata_lines[item[0]] - return self._combine_metadata( - list(itemgetter(*item)(metadata_lines))) + # Item must be some kind of integer sequence + return self._combine_metadata([metadata_lines[i] for i in item]) def copy(self) -> Self: """Get an independent copy of spectrum""" From b682432057539ce9eec8d5c80755f77cfdc65f6a Mon Sep 17 00:00:00 2001 From: "Adam J. Jackson" Date: Fri, 20 Sep 2024 16:55:19 +0100 Subject: [PATCH 35/37] Drop unused import --- euphonic/spectra.py | 1 - 1 file changed, 1 deletion(-) diff --git a/euphonic/spectra.py b/euphonic/spectra.py index f1ed6e5dd..e14b907e2 100644 --- a/euphonic/spectra.py +++ b/euphonic/spectra.py @@ -8,7 +8,6 @@ import math import json from numbers import Integral, Real -from operator import itemgetter from typing import (Any, Callable, Dict, Generator, List, Literal, Optional, overload, Sequence, Tuple, TypeVar, Union, Type) from typing_extensions import Self From 0f7ec817b6e503219cceb0474e5461141e0b0bc2 Mon Sep 17 00:00:00 2001 From: "Adam J.
Jackson" Date: Thu, 26 Sep 2024 11:51:34 +0100 Subject: [PATCH 36/37] More tidying for legibility Via discussion / pair-programming with @oerc0122 - Use native dict comprehension over keyfilter in iter_metadata: it's a bit ugly but no more complicated, and should be easier to read "casually" - Clearer comment re: value-pair combination - Replace a lambda with named partial function and toolz complement --- euphonic/spectra.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/euphonic/spectra.py b/euphonic/spectra.py index e14b907e2..9dbc2e43d 100644 --- a/euphonic/spectra.py +++ b/euphonic/spectra.py @@ -8,6 +8,7 @@ import math import json from numbers import Integral, Real +from operator import contains from typing import (Any, Callable, Dict, Generator, List, Literal, Optional, overload, Sequence, Tuple, TypeVar, Union, Type) from typing_extensions import Self @@ -17,6 +18,7 @@ import numpy as np from scipy.ndimage import correlate1d, gaussian_filter from toolz.dicttoolz import keyfilter, valmap +from toolz.functoolz import complement from toolz.itertoolz import groupby, pluck from euphonic import ureg, __version__ @@ -888,8 +890,9 @@ def __add__(self, other: Self) -> Self: def iter_metadata(self) -> Generator[OneLineData, None, None]: """Iterate over metadata dicts of individual spectra from collection""" - common_metadata = keyfilter(lambda key: key != "line_data", - self.metadata) + common_metadata = {key: value for key, value in self.metadata.items() + if key != "line_data"} + line_data = self.metadata.get("line_data") if line_data is None: @@ -986,14 +989,14 @@ def _combine_metadata(all_metadata: LineData) -> Metadata: for metadata in all_metadata: assert 'line_data' not in metadata.keys() - # Combine all common key/value pairs into new dict + # Combine key-value pairs common to *all* metadata lines into new dict common_metadata = dict( reduce(set.intersection, (set(metadata.items()) for metadata in all_metadata))) # Put all 
other per-spectrum metadata in line_data - line_data = [keyfilter(lambda key: key not in common_metadata, - one_line_data) + is_common = partial(contains, common_metadata) + line_data = [keyfilter(complement(is_common), one_line_data) for one_line_data in all_metadata] if any(line_data): From ef314a0fe22448b56b3d9a2cfb2bee3160d8f724 Mon Sep 17 00:00:00 2001 From: "Adam J. Jackson" Date: Thu, 26 Sep 2024 14:37:28 +0100 Subject: [PATCH 37/37] Remove redundant .keys() when iterating over dict Co-authored-by: Jacob Wilkins <46597752+oerc0122@users.noreply.github.com> --- euphonic/spectra.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/euphonic/spectra.py b/euphonic/spectra.py index 9dbc2e43d..e26092799 100644 --- a/euphonic/spectra.py +++ b/euphonic/spectra.py @@ -987,7 +987,7 @@ def _combine_metadata(all_metadata: LineData) -> Metadata: # This is for combining multiple separate spectrum metadata, # they shouldn't have line_data for metadata in all_metadata: - assert 'line_data' not in metadata.keys() + assert 'line_data' not in metadata # Combine key-value pairs common to *all* metadata lines into new dict common_metadata = dict(