Skip to content

Commit

Permalink
Quality of life fixes (#48)
Browse files Browse the repository at this point in the history
* Small qol fixes

* Bump changelog
  • Loading branch information
wfondrie authored Apr 25, 2024
1 parent d46adf1 commit 35bf3e7
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

## [v0.4.2]

### Changed
- The length of `SpectrumDataset` and `AnnotatedSpectrumDataset` is now the number of batches, not the number of spectra. This lets tools like PyTorch Lightning create their progress bars properly.
- Parsing a dataset no longer requires reading essentially the whole first file. The schema is now inferred from the first 128 spectra.

## [v0.4.1]

### Added
Expand Down
14 changes: 10 additions & 4 deletions depthcharge/data/spectrum_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

from __future__ import annotations

import copy
import logging
import math
import uuid
from collections.abc import Generator, Iterable
from os import PathLike
Expand Down Expand Up @@ -88,6 +90,10 @@ def __init__(
) -> None:
"""Initialize a SpectrumDataset."""
self._parse_kwargs = {} if parse_kwargs is None else parse_kwargs
self._init_kwargs = copy.copy(self._parse_kwargs)
self._init_kwargs["batch_size"] = 128
self._init_kwargs["progress"] = False

self._tmpdir = None
if path is None:
# Create a random temporary file:
Expand All @@ -101,7 +107,7 @@ def __init__(
# Now parse spectra.
if spectra is not None:
spectra = utils.listify(spectra)
batch = next(_get_records(spectra, **self._parse_kwargs))
batch = next(_get_records(spectra, **self._init_kwargs))
lance.write_dataset(
_get_records(spectra, **self._parse_kwargs),
str(self._path),
Expand Down Expand Up @@ -137,7 +143,7 @@ def add_spectra(
"""
spectra = utils.listify(spectra)
batch = next(_get_records(spectra, **self._parse_kwargs))
batch = next(_get_records(spectra, **self._init_kwargs))
self._dataset = lance.write_dataset(
_get_records(spectra, **self._parse_kwargs),
self._path,
Expand Down Expand Up @@ -167,8 +173,8 @@ def __getitem__(self, idx: int) -> dict[str, Any]:
return self._to_tensor(self._dataset.take(utils.listify(idx)))

def __len__(self) -> int:
"""The number of spectra in the lance dataset."""
return self._dataset.count_rows()
"""The number of batches in the lance dataset."""
return math.ceil(self._dataset.count_rows() / self.batch_size)

def __del__(self) -> None:
"""Cleanup the temporary directory."""
Expand Down

0 comments on commit 35bf3e7

Please sign in to comment.