From b8be2e28f7aca720361e97de2cb7976f580b2ed2 Mon Sep 17 00:00:00 2001
From: Will Fondrie
Date: Mon, 29 Apr 2024 15:19:06 -0700
Subject: [PATCH] Remove `__len__()` for `IterableDataset` classes (#50)

* Removed len and simplified
* Update warning
* Fix warning
* Remove unintentional flycheck files
---
 CHANGELOG.md                                |  7 +++++++
 depthcharge/data/parsers.py                 |  7 ++++---
 depthcharge/data/spectrum_datasets.py       | 28 ++++++++++++++--------------
 tests/unit_tests/test_data/test_datasets.py |  8 ++++----
 tests/unit_tests/test_data/test_loaders.py  |  2 +-
 5 files changed, 30 insertions(+), 22 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0b744cc..f91b65a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,6 +6,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
 ## [Unreleased]
 
+## [v0.4.4]
+
+### Changed
+- Partially revert the length changes to `SpectrumDataset` and `AnnotatedSpectrumDataset`: `__len__` was removed from both because it caused compatibility problems with PyTorch Lightning.
+- Simplify the dataset code by removing redundancy with `lance.pytorch.LanceDataset`.
+- Improve the warning message for skipped spectra.
+
 ## [v0.4.3]
 
 ### Changed
diff --git a/depthcharge/data/parsers.py b/depthcharge/data/parsers.py
index 491f408..2d9c3a0 100644
--- a/depthcharge/data/parsers.py
+++ b/depthcharge/data/parsers.py
@@ -3,6 +3,7 @@
 from __future__ import annotations
 
 import logging
+import warnings
 from abc import ABC, abstractmethod
 from collections.abc import Callable, Iterable
 from os import PathLike
@@ -223,10 +224,10 @@ def iter_batches(self, batch_size: int | None) -> pa.RecordBatch:
                 yield self._yield_batch()
 
         if n_skipped:
-            LOGGER.warning(
-                "Skipped %d spectra with invalid information", n_skipped
+            warnings.warn(
+                f"Skipped {n_skipped} spectra with invalid information. "
+                f"Last error was:\n{str(last_exc)}"
             )
-            LOGGER.debug("Last error: %s", str(last_exc))
 
     def _update_batch(self, entry: dict) -> None:
         """Update the batch.
diff --git a/depthcharge/data/spectrum_datasets.py b/depthcharge/data/spectrum_datasets.py
index 71bae2e..c394332 100644
--- a/depthcharge/data/spectrum_datasets.py
+++ b/depthcharge/data/spectrum_datasets.py
@@ -4,7 +4,6 @@
 
 import copy
 import logging
-import math
 import uuid
 from collections.abc import Generator, Iterable
 from os import PathLike
@@ -77,6 +76,8 @@ class SpectrumDataset(LanceDataset):
     ----------
     peak_files : list of str
     path : Path
+    n_spectra : int
+    dataset : lance.LanceDataset
 
     """
 
@@ -118,11 +119,11 @@ def __init__(
         elif not self._path.exists():
             raise ValueError("No spectra were provided")
 
-        self._dataset = lance.dataset(str(self._path))
+        dataset = lance.dataset(str(self._path))
         if "to_tensor_fn" not in kwargs:
             kwargs["to_tensor_fn"] = self._to_tensor
 
-        super().__init__(self._dataset, batch_size, **kwargs)
+        super().__init__(dataset, batch_size, **kwargs)
 
     def add_spectra(
         self,
@@ -144,7 +145,7 @@ def add_spectra(
         """
         spectra = utils.listify(spectra)
         batch = next(_get_records(spectra, **self._init_kwargs))
-        self._dataset = lance.write_dataset(
+        self.dataset = lance.write_dataset(
            _get_records(spectra, **self._parse_kwargs),
            self._path,
            mode="append",
@@ -170,26 +171,23 @@ def __getitem__(self, idx: int) -> dict[str, Any]:
            PyTorch tensors if the nested data type is compatible.

""" - return self._to_tensor(self._dataset.take(utils.listify(idx))) - - def __len__(self) -> int: - """The number of batches in the lance dataset.""" - num = self._dataset.count_rows() - if self.samples: - num = min(self.samples, num) - - return math.ceil(num / self.batch_size) + return self._to_tensor(self.dataset.take(utils.listify(idx))) def __del__(self) -> None: """Cleanup the temporary directory.""" if self._tmpdir is not None: self._tmpdir.cleanup() + @property + def n_spectra(self) -> int: + """The number of spectra in the Lance dataset.""" + return self.dataset.count_rows() + @property def peak_files(self) -> list[str]: """The files currently in the lance dataset.""" return ( - self._dataset.to_table(columns=["peak_file"]) + self.dataset.to_table(columns=["peak_file"]) .column(0) .unique() .to_pylist() @@ -320,6 +318,8 @@ class AnnotatedSpectrumDataset(SpectrumDataset): ---------- peak_files : list of str path : Path + n_spectra : int + dataset : lance.LanceDataset tokenizer : PeptideTokenizer The tokenizer for the annotations. annotations : str diff --git a/tests/unit_tests/test_data/test_datasets.py b/tests/unit_tests/test_data/test_datasets.py index b1bd1f2..2302d87 100644 --- a/tests/unit_tests/test_data/test_datasets.py +++ b/tests/unit_tests/test_data/test_datasets.py @@ -28,10 +28,10 @@ def tokenizer(): def test_addition(mgf_small, tmp_path): """Testing adding a file.""" dataset = SpectrumDataset(mgf_small, path=tmp_path / "test", batch_size=1) - assert len(dataset) == 2 + assert dataset.n_spectra == 2 dataset = dataset.add_spectra(mgf_small) - assert len(dataset) == 4 + assert dataset.n_spectra == 4 def test_indexing(tokenizer, mgf_small, tmp_path): @@ -197,7 +197,7 @@ def test_pickle(tokenizer, tmp_path, mgf_small): with pkl_file.open("rb") as pkl: loaded = pickle.load(pkl) - assert len(dataset) == len(loaded) + assert dataset.n_spectra == loaded.n_spectra dataset = AnnotatedSpectrumDataset( [mgf_small], @@ -214,4 +214,4 @@ def test_pickle(tokenizer, tmp_path, mgf_small): with pkl_file.open("rb") as pkl: loaded = pickle.load(pkl) - assert len(dataset) == len(loaded) + assert dataset.n_spectra == loaded.n_spectra diff --git a/tests/unit_tests/test_data/test_loaders.py b/tests/unit_tests/test_data/test_loaders.py index 7b2c48e..ea7bb02 100644 --- a/tests/unit_tests/test_data/test_loaders.py +++ b/tests/unit_tests/test_data/test_loaders.py @@ -37,7 +37,7 @@ def test_spectrum_loader_samples(mgf_small, tmp_path, samples, batches): ) loaded = list(DataLoader(dset)) - assert len(dset) == batches + assert dset.n_spectra == 2 assert len(loaded) == batches