From 35bf3e7a5422764972633bb0eae8943def582808 Mon Sep 17 00:00:00 2001 From: Will Fondrie Date: Wed, 24 Apr 2024 23:19:20 -0700 Subject: [PATCH] Quality of life fixes (#48) * Small qol fixes * Bump changelog --- CHANGELOG.md | 6 ++++++ depthcharge/data/spectrum_datasets.py | 14 ++++++++++---- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e10b68d..0b46a2c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [v0.4.2] + +### Changed +- The length of `SpectrumDataset` and `AnnotatedSpectrumDataset` is now the number of batches, not the number of spectra. This let's tools like PyTorch Lighting create their progress bars properly. +- Parsing a dataset now no longer requires reading essentially the whole first file. Now the schema is inferred from the first 128 spectra. + ## [v0.4.1] ### Added diff --git a/depthcharge/data/spectrum_datasets.py b/depthcharge/data/spectrum_datasets.py index d3f6da5..8ee983b 100644 --- a/depthcharge/data/spectrum_datasets.py +++ b/depthcharge/data/spectrum_datasets.py @@ -2,7 +2,9 @@ from __future__ import annotations +import copy import logging +import math import uuid from collections.abc import Generator, Iterable from os import PathLike @@ -88,6 +90,10 @@ def __init__( ) -> None: """Initialize a SpectrumDataset.""" self._parse_kwargs = {} if parse_kwargs is None else parse_kwargs + self._init_kwargs = copy.copy(self._parse_kwargs) + self._init_kwargs["batch_size"] = 128 + self._init_kwargs["progress"] = False + self._tmpdir = None if path is None: # Create a random temporary file: @@ -101,7 +107,7 @@ def __init__( # Now parse spectra. if spectra is not None: spectra = utils.listify(spectra) - batch = next(_get_records(spectra, **self._parse_kwargs)) + batch = next(_get_records(spectra, **self._init_kwargs)) lance.write_dataset( _get_records(spectra, **self._parse_kwargs), str(self._path), @@ -137,7 +143,7 @@ def add_spectra( """ spectra = utils.listify(spectra) - batch = next(_get_records(spectra, **self._parse_kwargs)) + batch = next(_get_records(spectra, **self._init_kwargs)) self._dataset = lance.write_dataset( _get_records(spectra, **self._parse_kwargs), self._path, @@ -167,8 +173,8 @@ def __getitem__(self, idx: int) -> dict[str, Any]: return self._to_tensor(self._dataset.take(utils.listify(idx))) def __len__(self) -> int: - """The number of spectra in the lance dataset.""" - return self._dataset.count_rows() + """The number of batches in the lance dataset.""" + return math.ceil(self._dataset.count_rows() / self.batch_size) def __del__(self) -> None: """Cleanup the temporary directory."""