-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #43 from FR-DC/FRML-81
FRML-81 Implement Stratified Sampling
- Loading branch information
Showing
6 changed files
with
201 additions
and
20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
from __future__ import annotations | ||
|
||
from typing import Iterator, Any, Sequence | ||
|
||
import pandas as pd | ||
import torch | ||
from sklearn.preprocessing import LabelEncoder | ||
from torch.utils.data import Sampler | ||
|
||
|
||
class RandomStratifiedSampler(Sampler[int]):
    """Sampler that draws dataset indices so each class is equally likely."""

    def __init__(
        self,
        targets: Sequence[Any],
        num_samples: int | None = None,
        replacement: bool = True,
    ) -> None:
        """Stratified sampling from a dataset, such that each class is
        sampled with equal probability.

        Examples:
            Use this with DataLoader to sample from a dataset in a stratified
            fashion. For example::

                ds = TensorDataset(...)
                dl = DataLoader(
                    ds,
                    batch_size=...,
                    sampler=RandomStratifiedSampler(targets),
                )

            This will use the targets' frequency as the inverse probability
            for sampling. For example, if the targets are [0, 0, 1, 2],
            there are 3 classes, so each class gets a total probability of
            1/3. That mass is shared between the members of the class, hence
            the per-sample probabilities are [1/6, 1/6, 1/3, 1/3].

        Args:
            targets: The targets to stratify by, one per dataset sample.
                They are label-encoded with ``LabelEncoder`` first, so they
                do not have to be integers (e.g. strings work too).
            num_samples: The number of samples to draw. If None (or 0), the
                number of samples is equal to the length of ``targets``.
            replacement: Whether indices are drawn with replacement. This is
                forwarded to ``torch.multinomial``.

        Raises:
            ValueError: If ``replacement`` is False and more samples are
                requested than there are targets.
        """
        super().__init__()

        # Given targets [0, 0, 1]
        # class_counts                          = [2, 1]
        # 1 / class_counts                      = [0.5, 1]
        # 1 / class_counts / len(class_counts)  = [0.25, 0.5]
        # Indexing by the encoded targets then projects the per-class value
        # back onto every sample: [0.25, 0.25, 0.5]
        targets_lab = torch.tensor(LabelEncoder().fit_transform(targets))
        class_counts = torch.bincount(targets_lab)
        per_class_probs = 1 / class_counts / len(class_counts)
        self.target_probs: torch.Tensor = per_class_probs[targets_lab]

        # A falsy num_samples (None or 0) falls back to one draw per target.
        self.num_samples = num_samples or len(targets)
        self.replacement = replacement

        # Fail early with a clear message instead of letting
        # torch.multinomial raise in the middle of the first epoch.
        if not self.replacement and self.num_samples > len(targets):
            raise ValueError(
                f"Cannot draw {self.num_samples} samples without replacement "
                f"from {len(targets)} targets."
            )

    def __len__(self) -> int:
        """Return the number of indices produced by one pass of the sampler."""
        return self.num_samples

    def __iter__(self) -> Iterator[int]:
        """Yield ``num_samples`` dataset indices drawn from ``target_probs``.

        The indices are yielded as plain Python ints (not 0-dim tensors), so
        that they match the ``Iterator[int]`` contract and can be used with
        any dataset's ``__getitem__``.
        """
        drawn = torch.multinomial(
            self.target_probs,
            num_samples=self.num_samples,
            replacement=self.replacement,
        )
        yield from drawn.tolist()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
from __future__ import annotations | ||
|
||
import torch | ||
from torch.utils.data import DataLoader, TensorDataset | ||
|
||
from frdc.train.stratified_sampling import RandomStratifiedSampler | ||
|
||
|
||
def test_stratifed_sampling_has_correct_probs():
    """Check the per-sample probabilities computed for targets A, A, B."""
    # Two classes -> each class weighs 0.5; class "A" has two members, so
    # each of them gets 0.25, while the single "B" keeps the full 0.5.
    expected_probs = torch.tensor([0.25, 0.25, 0.5])

    sampler = RandomStratifiedSampler(["A", "A", "B"])
    matches = sampler.target_probs == expected_probs

    assert torch.all(matches)
|
||
|
||
def test_stratified_sampling_fairly_samples():
    """This test checks that the stratified sampler works with a dataloader."""

    # This is a statistical test: seed the global RNG so that the draw, and
    # therefore the outcome of this test, is reproducible between runs.
    torch.manual_seed(0)

    # This is a simple example of a dataset with 2 classes.
    # The first 2 samples are class "A", the third is class "B".
    x = torch.tensor([0, 1, 2])
    y = ["A", "A", "B"]

    # To check that it's truly stratified, we'll sample 1000 times
    # then assert that both classes are sampled roughly equally.

    # In this case, the first 2 x should be sampled roughly 250 times each,
    # and the third x should be sampled roughly 500 times.

    num_samples = 1000
    batch_size = 10
    dl = DataLoader(
        TensorDataset(x),
        batch_size=batch_size,
        sampler=RandomStratifiedSampler(y, num_samples=num_samples),
    )

    # Note that when we sample from a TensorDataset, we get a tuple of tensors.
    # So we need to unpack the tuple.
    x_samples = torch.cat([batch for (batch,) in dl])

    assert len(x_samples) == num_samples
    assert torch.allclose(
        # minlength keeps the shape at 3 even if a value was never drawn, so
        # a failure is reported by the assertion rather than a shape error.
        torch.bincount(x_samples, minlength=len(x)),
        torch.tensor([250, 250, 500]),
        # atol is the absolute tolerance, so the result can differ by 50
        atol=50,
    )