-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathbalancing_sampler.py
36 lines (28 loc) · 1.22 KB
/
balancing_sampler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
''' Sampler object for PyTorch. Returns fixed number of samples per class. '''
from typing import Any, List
import numpy as np
import pandas as pd
import torch.utils.data
import numpy.random
class BalancingSampler(torch.utils.data.Sampler):
    ''' Sampler that emits the same number of rows for every class.

    Rows of `df` are grouped by the `landmark_id` column. On each pass,
    `images_per_class` index labels are taken from every group: a group with
    enough rows is subsampled without replacement, while a smaller group
    contributes all of its rows plus extra rows drawn with replacement.
    The combined labels are shuffled before being yielded.

    NOTE(review): the yielded values are labels of `df.index`, not
    positions; presumably the dataset is addressed by the same labels
    (e.g. a default RangeIndex) - confirm against the caller.
    '''

    def __init__(self, df: pd.DataFrame, images_per_class: int,
                 num_classes: int) -> None:
        ''' Stores the sampling configuration.

        Args:
            df: table with a `landmark_id` column giving the class of each row.
            images_per_class: number of index labels emitted for every class.
            num_classes: number of distinct `landmark_id` groups expected in
                `df`; used to compute `len(self)`.
        '''
        self.df = df
        self.images_per_class = images_per_class
        self.num_classes = num_classes

    def __iter__(self) -> Any:
        ''' Builds one balanced, shuffled pass over the data.

        Returns: iterator over the shuffled index labels.

        Raises:
            AssertionError: if the number of collected labels differs from
                `len(self)`, i.e. the number of `landmark_id` groups in `df`
                does not equal `num_classes`.
        '''
        indices: List[int] = []

        for _, group in self.df.groupby('landmark_id'):
            group_size = group.shape[0]

            if group_size >= self.images_per_class:
                indices.extend(numpy.random.choice(
                    group.index, self.images_per_class, replace=False))
            else:
                # Keep every row of an under-populated class, then pad up to
                # the quota by resampling the same rows with replacement.
                indices.extend(group.index)
                indices.extend(numpy.random.choice(
                    group.index, self.images_per_class - group_size,
                    replace=True))

        expected = len(self)
        if len(indices) != expected:
            # Raised explicitly instead of via `assert` so that the check is
            # not stripped when Python runs with -O; otherwise the sampler
            # would silently yield a count that disagrees with __len__.
            raise AssertionError(
                f'collected {len(indices)} indices, expected {expected} '
                f'(images_per_class * num_classes); check that num_classes '
                f'matches the number of landmark_id groups in df')

        return iter(np.random.permutation(indices))

    def __len__(self) -> int:
        ''' Returns: number of indices per pass (images_per_class * num_classes). '''
        return self.images_per_class * self.num_classes