Data samplers generate balanced batches for smooth training.
A well-balanced batch contains at least two examples of each class present in the batch.
Well-balanced batches matter for many types of similarity learning, including contrastive learning, because contrastive losses need at least two examples of a class (and sometimes more) to compute distances between embeddings.
To address this need, TensorFlow Similarity provides data samplers for various types of datasets that:
- Ensure that batches contain at least N examples of each class present in the batch.
- Support restricting the batches to a subset of the classes present in the dataset.
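
The two properties above can be illustrated with a minimal, library-independent sketch. It is not TensorFlow Similarity's implementation: the function `balanced_batch` and its parameters (`classes_per_batch`, `examples_per_class`, `class_list`, `rng`) are hypothetical names chosen for this example, and it returns example indices rather than tensors.

```python
import random
from collections import defaultdict


def balanced_batch(labels, classes_per_batch, examples_per_class,
                   class_list=None, rng=None):
    """Return indices for one batch with a fixed number of examples per class.

    Hypothetical helper for illustration only, not a library function.
    """
    rng = rng or random.Random()

    # Group example indices by their class label.
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    # Optionally restrict sampling to a subset of the classes.
    if class_list is None:
        candidates = sorted(by_class)
    else:
        candidates = [c for c in class_list if c in by_class]

    # Keep only classes that can supply enough distinct examples.
    eligible = [c for c in candidates
                if len(by_class[c]) >= examples_per_class]
    if len(eligible) < classes_per_batch:
        raise ValueError("not enough classes with enough examples")

    # Pick the classes, then pick the examples within each class.
    batch = []
    for c in rng.sample(eligible, classes_per_batch):
        batch.extend(rng.sample(by_class[c], examples_per_class))
    return batch


labels = [0, 0, 0, 1, 1, 2, 2, 2, 3]
batch = balanced_batch(labels, classes_per_batch=2, examples_per_class=2,
                       class_list=[0, 1, 2], rng=random.Random(0))
```

Every class that appears in `batch` contributes exactly two examples, and class `3` can never appear because it is outside `class_list`.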

- class MultiShotMemorySampler: Base object for fitting to a sequence of data, such as a dataset.
- class SingleShotMemorySampler: Base object for fitting to a sequence of data, such as a dataset.
- class TFDatasetMultiShotMemorySampler: Base object for fitting to a sequence of data, such as a dataset.

- TFRecordDatasetSampler(...): Create a TFRecordDataset-based sampler.
- select_examples(...): Randomly select at most N examples per class.
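
To make "at most N examples per class" concrete, here is a small standard-library sketch of that selection rule. It is an assumption-laden illustration, not the library's `select_examples`: the helper name `select_at_most_n_per_class`, its parameters, and its index-based return value are all hypothetical.

```python
import random
from collections import defaultdict


def select_at_most_n_per_class(labels, n, rng=None):
    """Return sorted indices with at most `n` randomly chosen examples per class.

    Hypothetical helper for illustration only, not a library function.
    """
    rng = rng or random.Random()

    # Group example indices by their class label.
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    selected = []
    for indices in by_class.values():
        # Classes with fewer than n examples contribute everything they have.
        k = min(n, len(indices))
        selected.extend(rng.sample(indices, k))
    return sorted(selected)


labels = ["cat", "cat", "cat", "dog", "bird", "bird"]
picked = select_at_most_n_per_class(labels, n=2, rng=random.Random(42))
```

With these labels, `picked` holds two `cat` indices, the single `dog` index, and both `bird` indices; which two `cat` examples are kept depends on the random seed.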