diff --git a/euphonic/spectra.py b/euphonic/spectra.py index 66219fd0b..869d3917e 100644 --- a/euphonic/spectra.py +++ b/euphonic/spectra.py @@ -776,7 +776,7 @@ def sum(self) -> Spectrum: np.sum(self._get_raw_spectrum_data(), axis=0), units=self._get_internal_spectrum_data_unit() ).to(self._get_spectrum_data_unit()) - return Spectrum1D( + return self._item_type( **self._get_bin_kwargs(), **{self._spectrum_data_name(): summed_s_data}, x_tick_labels=copy.copy(self.x_tick_labels), diff --git a/tests_and_analysis/test/euphonic_test/test_spectrum2dcollection.py b/tests_and_analysis/test/euphonic_test/test_spectrum2dcollection.py index acc5c46e4..99218d497 100644 --- a/tests_and_analysis/test/euphonic_test/test_spectrum2dcollection.py +++ b/tests_and_analysis/test/euphonic_test/test_spectrum2dcollection.py @@ -19,6 +19,40 @@ def get_spectrum2dcollection(json_filename): get_spectrum2dcollection_path(json_filename)) +@pytest.fixture +def quartz_fuzzy_collection(): + return get_spectrum2dcollection("quartz_fuzzy_map.json") + + +@pytest.fixture +def quartz_fuzzy_items(): + return [get_spectrum2d(f"quartz_fuzzy_map_{i}.json") for i in range(3)] + +@pytest.fixture +def inconsistent_x_item(): + item = get_spectrum2d("quartz_fuzzy_map_0.json") + item._x_data *= 2. + return item + +@pytest.fixture +def inconsistent_x_units_item(): + item = get_spectrum2d("quartz_fuzzy_map_0.json") + item.x_data_unit = "1/bohr" + return item + +@pytest.fixture +def inconsistent_x_length_item(): + item = get_spectrum2d("quartz_fuzzy_map_0.json") + item._x_data = item._x_data[:-2] + item._z_data = item._z_data[:-2, :] + return item + +@pytest.fixture +def inconsistent_y_item(): + item = get_spectrum2d("quartz_fuzzy_map_0.json") + item._y_data *= 2. + return item + def rand_spectrum2d(seed: int = 1, x_bins: Optional[Quantity] = None, y_bins: Optional[Quantity] = None, @@ -66,30 +100,17 @@ def test_init_from_numbers(self): x_data, y_data, z_data, x_tick_labels=x_tick_labels, metadata=metadata) - assert spectrum - - def test_init_from_spectra(self): - """Construct collection from a series of Spectrum2D""" - spec_2d = rand_spectrum2d(seed=1) - spec_2d_consistent = rand_spectrum2d().copy() - spec_2d_consistent._z_data *= 2 - spec_2d.metadata["index"] = 2 - - spectrum = Spectrum2DCollection.from_spectra( - [spec_2d, spec_2d_consistent]) - - spec_2d_inconsistent = rand_spectrum2d(seed=2) - with pytest.raises(ValueError): - spectrum = Spectrum2DCollection.from_spectra( - [spec_2d, spec_2d_inconsistent]) - assert spectrum + for attr, data in [("x_data", x_data), + ("y_data", y_data), + ("z_data", z_data)]: + np.testing.assert_allclose(getattr(spectrum, attr), data) - def test_from_spectra(self): - spectra = [get_spectrum2d(f"quartz_fuzzy_map_{i}.json") - for i in range(3)] - collection = Spectrum2DCollection.from_spectra(spectra) + assert spectrum.metadata == metadata - ref_collection = get_spectrum2dcollection("quartz_fuzzy_map.json") + def test_from_spectra(self, quartz_fuzzy_collection, quartz_fuzzy_items): + """Use alternate constructor Spectrum2DCollection.from_spectra()""" + collection = Spectrum2DCollection.from_spectra(quartz_fuzzy_items) + ref_collection = quartz_fuzzy_collection for attr in ("x_data", "y_data", "z_data"): new, ref = getattr(collection, attr), getattr(ref_collection, attr) @@ -101,7 +122,36 @@ def test_from_spectra(self): else: assert ref_collection.metadata == collection.metadata - def test_indexing(self): + def test_from_bad_spectra( + self, + quartz_fuzzy_items, + inconsistent_x_item, + inconsistent_x_length_item, + inconsistent_x_units_item, + inconsistent_y_item): + """Spectrum2DCollection.from_spectra with inconsistent input""" + + with pytest.raises(AssertionError): + Spectrum2DCollection.from_spectra( + quartz_fuzzy_items + [inconsistent_x_item] + ) + + with pytest.raises(AssertionError): + Spectrum2DCollection.from_spectra( + quartz_fuzzy_items + [inconsistent_x_units_item] + ) + + with pytest.raises(ValueError): + Spectrum2DCollection.from_spectra( + quartz_fuzzy_items + [inconsistent_x_length_item] + ) + + with pytest.raises(AssertionError): + Spectrum2DCollection.from_spectra( + quartz_fuzzy_items + [inconsistent_y_item] + ) + + def test_indexing(self, quartz_fuzzy_collection, quartz_fuzzy_items): """Check indexing an element, slice and iteration - Individual index should yield corresponding Spectrum2D @@ -109,20 +159,35 @@ def test_indexing(self): - Iteration should yield a series of Spectrum2D """ - # TODO move spectrum load to a common fixture - - spectra = [get_spectrum2d(f"quartz_fuzzy_map_{i}.json") - for i in range(3)] - collection = get_spectrum2dcollection("quartz_fuzzy_map.json") - - item_1 = collection[1] + item_1 = quartz_fuzzy_collection[1] assert isinstance(item_1, Spectrum2D) - check_spectrum2d(item_1, spectra[1]) + check_spectrum2d(item_1, quartz_fuzzy_items[1]) - item_1_to_end = collection[1:] + item_1_to_end = quartz_fuzzy_collection[1:] assert isinstance(item_1_to_end, Spectrum2DCollection) - assert item_1_to_end != collection + assert item_1_to_end != quartz_fuzzy_collection - for item, ref in zip(item_1_to_end, spectra[1:]): + for item, ref in zip(item_1_to_end, quartz_fuzzy_items[1:]): assert isinstance(item, Spectrum2D) check_spectrum2d(item, ref) + + def test_collection_methods(self, quartz_fuzzy_collection): + """Check methods from SpectrumCollectionMixin + + These are checked thoroughly for Spectrum1DCollection, but here we + try to ensure the generic implementation works correctly in 2-D + + """ + + total = quartz_fuzzy_collection.sum() + assert isinstance(total, Spectrum2D) + assert total.z_data[3, 3] == sum(spec.z_data[3, 3] + for spec in quartz_fuzzy_collection) + + extended = quartz_fuzzy_collection + quartz_fuzzy_collection + assert len(extended) == 2 * len(quartz_fuzzy_collection) + np.testing.assert_allclose(extended.sum().z_data, total.z_data * 2) + + selection = quartz_fuzzy_collection.select(direction=2, common="yes") + ref_item_2 = get_spectrum2d("quartz_fuzzy_map_2.json") + check_spectrum2d(selection.sum(), ref_item_2)