Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stratified/equal sampling in FewShotDataLoader #15

Merged
merged 7 commits into from
Dec 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion experiments/02_openml.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def main():
train_loader = ComposedDataLoaderFactory.create_composed_dataloader_from_path(
Path(config["train_data_path"]),
RandomFeaturesPandasDataset,
{},
{"total_random_feature_sampling": True},
FewShotDataLoader,
{"support_size": config["support_size"], "query_size": config["query_size"]},
ComposedDataLoader,
Expand Down
102 changes: 98 additions & 4 deletions liltab/data/dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

from copy import deepcopy
from torch import Tensor
from torch.utils.data import Dataset
from typing import Iterable, OrderedDict
from typing import Iterable, OrderedDict, Dict, Union

from liltab.data.datasets import PandasDataset, RandomFeaturesPandasDataset


class FewShotDataLoader:
Expand All @@ -15,28 +16,83 @@ class FewShotDataLoader:

def __init__(
    self,
    dataset: Union[PandasDataset, RandomFeaturesPandasDataset],
    support_size: int,
    query_size: int,
    n_episodes: int = None,
    sample_classes_equally: bool = False,
    sample_classes_stratified: bool = False,
):
    """
    Args:
        dataset (Union[PandasDataset, RandomFeaturesPandasDataset]):
            dataset to load data from.
        support_size (int): number of observations in each episode's support set.
        query_size (int): number of observations in each episode's query set.
        n_episodes (int, optional): number of episodes to yield.
            If None, the loader iterates without end. Defaults to None.
        sample_classes_equally (bool, optional): when True, every episode
            contains an equal number of observations per class.
            Classification only. Defaults to False.
        sample_classes_stratified (bool, optional): when True, every episode
            contains observations per class proportional to the class
            frequencies. Classification only. Defaults to False.

    Raises:
        ValueError: if both sampling modes are requested at once, or if a
            set size cannot accommodate one observation per class.
    """
    self.dataset = dataset
    self.support_size = support_size
    self.query_size = query_size
    self.n_episodes = n_episodes
    self.sample_classes_equally = sample_classes_equally
    self.sample_classes_stratified = sample_classes_stratified
    # Equal and stratified sampling are mutually exclusive modes.
    if self.sample_classes_equally and self.sample_classes_stratified:
        raise ValueError("Only one of equal or stratified sampling can be used.")

    self.curr_episode = 0
    self.n_rows = len(self.dataset)

    if self.sample_classes_equally or self.sample_classes_stratified:
        # NOTE(review): raw (un-encoded) targets are assumed here — a reviewer
        # suggested deriving this via torch one-hot in the datasets module
        # to cut memory; confirm before relying on raw_y elsewhere.
        self.y = dataset.raw_y
        self.class_values = np.unique(self.y)
        # Each set must be able to hold at least one row per distinct class.
        if len(self.class_values) > self.support_size:
            raise ValueError(
                "When sampling equally the support size should "
                "be higher than number of distinct values"
            )
        if len(self.class_values) > self.query_size:
            raise ValueError(
                "When sampling equally the query size should "
                "be higher than number of distinct values"
            )
        # Pre-compute row indices per class once; reused on every episode.
        self.class_values_idx = {
            class_value: np.where(self.y == class_value)[0]
            for class_value in self.class_values
        }

        if sample_classes_equally:
            self._init_samples_per_class_equal()

        if sample_classes_stratified:
            self._init_samples_per_class_stratified()

def _init_samples_per_class_equal(self):
self.samples_per_class_support = {
class_value: self.support_size // len(self.class_values)
for class_value in self.class_values
}
self.samples_per_class_query = {
class_value: self.query_size // len(self.class_values)
for class_value in self.class_values
}

def _init_samples_per_class_stratified(self):
self.samples_per_class_support = {
class_value: int(self.support_size * (self.y == class_value).sum() / len(self.y))
for class_value in self.class_values
}
self.samples_per_class_query = {
class_value: int(self.query_size * (self.y == class_value).sum() / len(self.y))
for class_value in self.class_values
}

def __iter__(self):
return deepcopy(self)

Expand All @@ -54,6 +110,44 @@ def __next__(self) -> tuple[Tensor, Tensor, Tensor, Tensor]:
raise StopIteration()
self.curr_episode += 1

if self.sample_classes_equally or self.sample_classes_stratified:
return self._sample_with_custom_proportion_classes()
else:
return self._sample_without_stratified_classes()

def _sample_with_custom_proportion_classes(self):
support_indices = self._generate_stratified_sampling_idx(
self.samples_per_class_support, self.support_size
)
query_indices = self._generate_stratified_sampling_idx(
self.samples_per_class_query, self.query_size
)
support_indices = np.random.permutation(support_indices)
query_indices = np.random.permutation(query_indices)
return *self.dataset[support_indices], *self.dataset[query_indices]

def _generate_stratified_sampling_idx(
self, samples_per_class_dict: Dict[int, np.ndarray], set_size: int
) -> list[int]:
sampled_indices = []
for val, idx in self.class_values_idx.items():
replace = samples_per_class_dict[val] > len(idx)
sampled_indices.extend(
np.random.choice(idx, samples_per_class_dict[val], replace=replace)
)
remaining_to_sample = set_size - len(sampled_indices)
if remaining_to_sample > 0:
available_idx_for_sampling = list(set(range(self.n_rows)) - set(sampled_indices))
replace = len(available_idx_for_sampling) > remaining_to_sample
sampled_indices.extend(
np.random.choice(available_idx_for_sampling, remaining_to_sample, replace=replace)
)

return sampled_indices

def _sample_without_stratified_classes(
self,
) -> tuple[Tensor, Tensor, Tensor, Tensor]:
replace = True if self.support_size + self.query_size >= self.n_rows else False
all_drawn_indices = np.random.choice(
self.n_rows, self.support_size + self.query_size, replace=replace
Expand Down
Loading