Add fnames_per_batch argument to HDF5Dataset (#191)
* add option for specifying fnames per batch

* add tests for fname limit

* add check for fnames > files_per_batch

* add remove-print-statements hook

* fix tests by increasing files sampled
EthanMarx authored Jan 31, 2025
1 parent 03beee4 commit b0548be
Showing 3 changed files with 56 additions and 5 deletions.
4 changes: 4 additions & 0 deletions .pre-commit-config.yaml
@@ -24,3 +24,7 @@ repos:
- id: poetry-check
- id: poetry-lock
args: [--check, --no-update]
- repo: https://github.com/dhruvmanila/remove-print-statements
rev: v0.5.2
hooks:
- id: remove-print-statements
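
A brief usage note (standard pre-commit behavior, not shown in this diff): after running pre-commit install, the new hook strips stray print statements on each commit; it can also be applied to the whole tree with pre-commit run remove-print-statements --all-files.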
40 changes: 35 additions & 5 deletions ml4gw/dataloading/hdf5_dataset.py
@@ -1,5 +1,5 @@
import warnings
from typing import Sequence, Union
from typing import Optional, Sequence, Union

import h5py
import numpy as np
@@ -50,6 +50,13 @@ class Hdf5TimeSeriesDataset(torch.utils.data.IterableDataset):
channel. The latter setting limits the amount of
entropy in the effective dataset, but can provide
over 2x improvement in total throughput.
num_files_per_batch:
The number of unique files from which to draw
the elements of each batch. If left as `None`,
all available files will be used. Useful when
reading from many files bottlenecks data loading.
"""

def __init__(
@@ -60,20 +67,29 @@ def __init__(
batch_size: int,
batches_per_epoch: int,
coincident: Union[bool, str],
num_files_per_batch: Optional[int] = None,
) -> None:
if not isinstance(coincident, bool) and coincident != "files":
raise ValueError(
"coincident must be either a boolean or 'files', "
"got unrecognized value {}".format(coincident)
)

self.fnames = fnames
self.fnames = np.array(fnames)
self.channels = channels
self.num_channels = len(channels)
self.kernel_size = kernel_size
self.batch_size = batch_size
self.batches_per_epoch = batches_per_epoch
self.coincident = coincident
self.num_files_per_batch = (
len(fnames) if num_files_per_batch is None else num_files_per_batch
)
if self.num_files_per_batch > len(fnames):
raise ValueError(
f"Number of files per batch ({self.num_files_per_batch}) "
f"cannot exceed number of files ({len(fnames)}) "
)

self.sizes = {}
for fname in self.fnames:
@@ -85,23 +101,37 @@ def __init__(
"without using chunked storage. This can have "
"severe performance impacts at data loading time. "
"If you need faster loading, try re-generating "
"your datset with chunked storage turned on.".format(
"your dataset with chunked storage turned on.".format(
fname
),
category=ContiguousHdf5Warning,
)

self.sizes[fname] = len(dset)

total = sum(self.sizes.values())
self.probs = np.array([i / total for i in self.sizes.values()])

def __len__(self) -> int:
return self.batches_per_epoch

def sample_fnames(self, size) -> np.ndarray:
return np.random.choice(
self.fnames,
# first, randomly select `self.num_files_per_batch`
# file indices based on their probabilities
fname_indices = np.arange(len(self.fnames))
fname_indices = np.random.choice(
fname_indices,
p=self.probs,
size=(self.num_files_per_batch),
replace=False,
)
# now renormalize the probabilities, and sample
# the requested size from this subset of files
probs = self.probs[fname_indices]
probs /= probs.sum()
return np.random.choice(
self.fnames[fname_indices],
p=probs,
size=size,
replace=True,
)
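
For reference, below is a minimal standalone sketch of the two-stage sampling that sample_fnames now performs; the file names, sizes, and subset size are invented for illustration and are not part of ml4gw.

import numpy as np

fnames = np.array(["a.h5", "b.h5", "c.h5"])
sizes = np.array([1000, 500, 250])  # samples available in each file
probs = sizes / sizes.sum()  # weight each file by its size
num_files_per_batch = 2

# stage 1: pick a subset of files without replacement,
# with larger files more likely to be included
idx = np.random.choice(
    len(fnames), size=num_files_per_batch, replace=False, p=probs
)

# stage 2: renormalize the subset's weights to sum to 1, then assign
# each batch element a file, sampling with replacement from the subset
subset_probs = probs[idx] / probs[idx].sum()
batch_fnames = np.random.choice(fnames[idx], size=(8,), p=subset_probs, replace=True)

The renormalization step is needed because, once files are dropped, the remaining weights no longer sum to 1, which np.random.choice requires of its p argument.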
17 changes: 17 additions & 0 deletions tests/dataloading/test_hdf5_dataset.py
@@ -107,6 +107,23 @@ def test_sample_fnames(self, dataset):
counts[fname.name] += 1
assert counts["a.h5"] > counts["b.h5"]

# when not specifying fnames per batch,
# ensure all 3 files are sampled
fnames = dataset.sample_fnames((10000,))
assert len(np.unique(fnames)) == 3

# override fnames per batch for testing
dataset.num_files_per_batch = 1
fnames = dataset.sample_fnames(size=(10,))
assert len(np.unique(fnames)) == 1

# use a size (1000) large enough that each of
# the two selected files is almost certainly
# sampled at least once
dataset.num_files_per_batch = 2
fnames = dataset.sample_fnames(size=(1000,))
assert len(np.unique(fnames)) == 2

def test_sample_batch(self, dataset, kernel_size, coincident):
x = dataset.sample_batch()
assert x.shape == (128, 2, kernel_size)
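
Putting the change together, here is a hypothetical usage sketch. It assumes Hdf5TimeSeriesDataset is importable from ml4gw.dataloading (the class lives in ml4gw/dataloading/hdf5_dataset.py) and that iterating the dataset yields one batch per step; the file names, channels, and sizes below are illustrative only.

from ml4gw.dataloading import Hdf5TimeSeriesDataset

dataset = Hdf5TimeSeriesDataset(
    fnames=["a.h5", "b.h5", "c.h5"],
    channels=["H1", "L1"],
    kernel_size=2048,
    batch_size=128,
    batches_per_epoch=10,
    coincident=False,
    # draw each batch from at most 2 of the 3 files to reduce
    # the number of files that must be opened per batch
    num_files_per_batch=2,
)

for x in dataset:
    # per the test above, each batch has shape
    # (batch_size, num_channels, kernel_size)
    assert x.shape == (128, 2, 2048)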
