Skip to content

Commit

Permalink
Reflect samples argument in SpectrumDataset lengths (#49)
Browse files Browse the repository at this point in the history
* Fix length of dataset

* Bump changelog

* Add condition when samples is 0
  • Loading branch information
wfondrie authored Apr 26, 2024
1 parent 35bf3e7 commit 15d52f4
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 1 deletion.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

## [v0.4.3]

### Changed
- The lengths of `SpectrumDataset` and `AnnotatedSpectrumDataset` now reflect the `samples` parameter of the `lance.pytorch.LanceDataset` parent class.

## [v0.4.2]

### Changed
Expand Down
6 changes: 5 additions & 1 deletion depthcharge/data/spectrum_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,11 @@ def __getitem__(self, idx: int) -> dict[str, Any]:

def __len__(self) -> int:
    """Return the number of batches the dataset will yield.

    The row count of the underlying lance dataset is capped at
    ``self.samples`` when that attribute is truthy. The capped count is
    then divided by ``self.batch_size`` and rounded up, so a trailing
    partial batch is counted as one batch.

    Returns
    -------
    int
        The number of batches.
    """
    num = self._dataset.count_rows()
    # A falsy ``samples`` (``None`` or 0) means no cap is applied.
    if self.samples:
        num = min(self.samples, num)

    return math.ceil(num / self.batch_size)

def __del__(self) -> None:
"""Cleanup the temporary directory."""
Expand Down
16 changes: 16 additions & 0 deletions tests/unit_tests/test_data/test_loaders.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Test PyTorch DataLoaders."""

import pyarrow as pa
import pytest
import torch
from torch.utils.data import DataLoader

Expand All @@ -25,6 +26,21 @@ def test_spectrum_loader(mgf_small, tmp_path):
assert isinstance(batch["mz_array"], torch.Tensor)


@pytest.mark.parametrize(["samples", "batches"], [(1, 1), (50, 2), (4, 2)])
def test_spectrum_loader_samples(mgf_small, tmp_path, samples, batches):
    """Check the dataset length and loaded batch count for each ``samples``."""
    dataset_kwargs = {
        "batch_size": 1,
        "path": tmp_path / "test",
        "samples": samples,
    }
    dataset = SpectrumDataset(mgf_small, **dataset_kwargs)

    # Exhaust the loader first, then compare both counts to the expectation.
    yielded = [*DataLoader(dataset)]
    assert len(dataset) == batches
    assert len(yielded) == batches


def test_streaming_spectrum_loader(mgf_small, tmp_path):
"""Test streaming spectra."""
streamer = StreamingSpectrumDataset(mgf_small, batch_size=2)
Expand Down

0 comments on commit 15d52f4

Please sign in to comment.