From 7477d2bf20749e393c437007564bd87519b7e70d Mon Sep 17 00:00:00 2001 From: Jarrett Ye Date: Wed, 30 Oct 2024 19:25:20 +0800 Subject: [PATCH] refactor: Optimize batch shuffling implementation for better performance (#252) * improve performance * refactor ShuffleDataLoader * add more assertion * bump version --- Cargo.lock | 2 +- Cargo.toml | 2 +- src/batch_shuffle.rs | 554 ++++++++----------------------------------- src/dataset.rs | 24 +- src/training.rs | 21 +- 5 files changed, 134 insertions(+), 469 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 87ca26c..11103a7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1077,7 +1077,7 @@ dependencies = [ [[package]] name = "fsrs" -version = "1.4.2" +version = "1.4.3" dependencies = [ "burn", "chrono", diff --git a/Cargo.toml b/Cargo.toml index 87d9637..5497d56 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "fsrs" -version = "1.4.2" +version = "1.4.3" authors = ["Open Spaced Repetition"] categories = ["algorithms", "science"] edition = "2021" diff --git a/src/batch_shuffle.rs b/src/batch_shuffle.rs index 52cbbf9..9ed1501 100644 --- a/src/batch_shuffle.rs +++ b/src/batch_shuffle.rs @@ -1,70 +1,34 @@ -use burn::data::{ - dataloader::{ - batcher::DynBatcher, BatchStrategy, DataLoader, DataLoaderIterator, FixBatchStrategy, - Progress, - }, - dataset::Dataset, -}; +use std::sync::Mutex; -use rand::{distributions::Standard, prelude::SliceRandom, rngs::StdRng, Rng, SeedableRng}; -use std::{ - marker::PhantomData, - sync::{Arc, Mutex}, -}; +use burn::data::dataloader::batcher::Batcher; +use burn::data::dataloader::{DataLoaderIterator, Progress}; +use burn::prelude::Backend; +use rand::seq::SliceRandom; +use rand::SeedableRng; -use crate::{dataset::FSRSDataset, FSRSItem}; +use crate::dataset::{FSRSBatch, FSRSBatcher, FSRSDataset}; -pub(crate) struct BatchShuffledDataset { - dataset: Arc, - indices: Vec, - input: PhantomData, +#[derive(Clone)] +pub(crate) struct BatchTensorDataset { + dataset: Vec>, } -impl BatchShuffledDataset { +impl BatchTensorDataset { /// Creates a new shuffled dataset. - pub fn new(dataset: Arc, batch_size: usize, rng: &mut StdRng) -> Self { - let len = dataset.len(); - - // Calculate the number of batches - // 计算批数 - let num_batches = (len + batch_size - 1) / batch_size; - - // Create a vector of batch indices and shuffle it - // 创建一个批数索引的向量并打乱 - let mut batch_indices: Vec<_> = (0..num_batches).collect(); - batch_indices.shuffle(rng); - // info!("batch_indices: {:?}", &batch_indices); - // Generate the corresponding item indices for each shuffled batch - // 为每个打乱的批次生成相应的元素索引 - let mut indices = vec![]; - for batch_index in batch_indices { - let start_index = batch_index * batch_size; - let end_index = (start_index + batch_size).min(len); - indices.extend(start_index..end_index); - } - // info!("indices: {:?}", &indices); - Self { - dataset, - indices, - input: PhantomData, - } - } - - /// Creates a new shuffled dataset with a fixed seed. - pub fn with_seed(dataset: Arc, batch_size: usize, seed: u64) -> Self { - let mut rng = StdRng::seed_from_u64(seed); - Self::new(dataset, batch_size, &mut rng) + pub fn new(dataset: FSRSDataset, batch_size: usize, device: B::Device) -> Self { + let batcher = FSRSBatcher::::new(device); + let dataset = dataset + .items + .chunks(batch_size) + .map(|items| batcher.batch(items.to_vec())) + .collect(); + Self { dataset } } } -impl Dataset for BatchShuffledDataset { - fn get(&self, index: usize) -> Option { - let shuffled_index = self.indices.get(index)?; - // info!( - // "original index: {}, shuffled index: {}", - // index, shuffled_index - // ); - self.dataset.get(*shuffled_index) +impl BatchTensorDataset { + fn get(&self, index: usize) -> Option> { + self.dataset.get(index).cloned() } fn len(&self) -> usize { @@ -72,129 +36,36 @@ impl Dataset for BatchShuffledDataset { } } -/// A data loader that can be used to iterate over a dataset in batches. -pub struct BatchShuffledDataLoader { - strategy: Box>, - dataset: Arc, - batcher: Box>, +pub struct ShuffleDataLoader { + dataset: BatchTensorDataset, rng: Mutex, - batch_size: usize, } -impl BatchShuffledDataLoader { - /// Creates a new batch data loader. - /// - /// # Arguments - /// - /// * `strategy` - The batch strategy. - /// * `dataset` - The dataset. - /// * `batcher` - The batcher. - /// * `rng` - The rng determining if the dataset is shuffled each time a dataloader - /// iterator is created. - /// - /// # Returns - /// - /// The batch data loader. - pub fn new( - strategy: Box>, - dataset: Arc, - batcher: Box>, - rng: rand::rngs::StdRng, - batch_size: usize, - ) -> Self { +impl ShuffleDataLoader { + pub fn new(dataset: BatchTensorDataset, seed: u64) -> Self { Self { - strategy, dataset, - batcher, - rng: Mutex::new(rng), - batch_size, + rng: Mutex::new(rand::rngs::StdRng::seed_from_u64(seed)), } } } -/// A data loader iterator that can be used to iterate over a data loader. -struct BatchShuffledDataloaderIterator { +pub(crate) struct ShuffleDataLoaderIterator { current_index: usize, - strategy: Box>, - dataset: Arc>, - batcher: Box>, -} - -impl DataLoader for BatchShuffledDataLoader -where - BatchShuffledDataset: Dataset, -{ - fn iter<'a>(&'a self) -> Box + 'a> { - // When starting a new iteration, we first check if the dataloader was created with an rng, - // implying that we should shuffle the dataset beforehand, while advancing the current - // rng to ensure that each new iteration shuffles the dataset differently. - let dataset = Arc::new(BatchShuffledDataset::with_seed( - self.dataset.clone(), - self.batch_size, - self.rng.lock().unwrap().sample(Standard), - )); - Box::new(BatchShuffledDataloaderIterator::new( - self.strategy.clone_dyn(), - dataset, - self.batcher.clone_dyn(), - )) - } - - fn num_items(&self) -> usize { - self.dataset.len() - } + indices: Vec, + dataset: BatchTensorDataset, } -impl BatchShuffledDataloaderIterator -where - BatchShuffledDataset: Dataset, -{ - /// Creates a new batch data loader iterator. - /// - /// # Arguments - /// - /// * `strategy` - The batch strategy. - /// * `dataset` - The dataset. - /// * `batcher` - The batcher. - /// - /// # Returns - /// - /// The batch data loader iterator. - pub fn new( - strategy: Box>, - dataset: Arc>, - batcher: Box>, - ) -> Self { +impl ShuffleDataLoaderIterator { + pub(crate) fn new(dataset: BatchTensorDataset, indices: Vec) -> Self { Self { current_index: 0, - strategy, + indices, dataset, - batcher, - } - } -} - -impl Iterator for BatchShuffledDataloaderIterator { - type Item = O; - - fn next(&mut self) -> Option { - while let Some(item) = self.dataset.get(self.current_index) { - self.current_index += 1; - self.strategy.add(item); - - if let Some(items) = self.strategy.batch(false) { - return Some(self.batcher.batch(items)); - } } - - let items = self.strategy.batch(true)?; - - Some(self.batcher.batch(items)) } -} -impl DataLoaderIterator for BatchShuffledDataloaderIterator { - fn progress(&self) -> Progress { + pub(crate) fn progress(&self) -> Progress { Progress { items_processed: self.current_index, items_total: self.dataset.len(), @@ -202,323 +73,106 @@ impl DataLoaderIterator for BatchShuffledDataloaderIterator { } } -/// A builder for data loaders. -pub struct BatchShuffledDataLoaderBuilder { - batcher: Box>, -} +impl Iterator for ShuffleDataLoaderIterator { + type Item = FSRSBatch; -impl BatchShuffledDataLoaderBuilder -where - I: Send + Sync + Clone + std::fmt::Debug + 'static, - O: Send + Clone + std::fmt::Debug + 'static, - BatchShuffledDataset: Dataset, -{ - /// Creates a new data loader builder. - /// - /// # Arguments - /// - /// * `batcher` - The batcher. - /// - /// # Returns - /// - /// The data loader builder. - pub fn new(batcher: B) -> Self - where - B: DynBatcher + 'static, - { - Self { - batcher: Box::new(batcher), + fn next(&mut self) -> Option { + if let Some(index) = self.indices.get(self.current_index) { + self.current_index += 1; + return self.dataset.get(*index); } + None } +} - /// Builds the data loader. - /// - /// # Arguments - /// - /// * `dataset` - The dataset. - /// - /// # Returns - /// - /// The data loader. - pub fn build( - self, - dataset: FSRSDataset, - batch_size: usize, - seed: u64, - ) -> Arc> { - let dataset = Arc::new(dataset); - - let rng = StdRng::seed_from_u64(seed); - let strategy = Box::new(FixBatchStrategy::new(batch_size)); +impl DataLoaderIterator> for ShuffleDataLoaderIterator { + fn progress(&self) -> Progress { + Progress::new(self.current_index, self.dataset.len()) + } +} - Arc::new(BatchShuffledDataLoader::new( - strategy, - dataset, - self.batcher, - rng, - batch_size, - )) +impl ShuffleDataLoader { + pub(crate) fn iter(&self) -> ShuffleDataLoaderIterator { + let mut indices: Vec<_> = (0..self.dataset.len()).collect(); + indices.shuffle(&mut *self.rng.lock().unwrap()); + ShuffleDataLoaderIterator::new(self.dataset.clone(), indices) } } #[cfg(test)] mod tests { - use burn::backend::{ndarray::NdArrayDevice, NdArray}; + use burn::{ + backend::{ndarray::NdArrayDevice, NdArray}, + tensor::Shape, + }; use super::*; use crate::{ - convertor_tests::anki21_sample_file_converted_to_fsrs, - dataset::{prepare_training_data, FSRSBatcher, FSRSDataset}, - FSRSItem, FSRSReview, + convertor_tests::anki21_sample_file_converted_to_fsrs, dataset::prepare_training_data, }; #[test] - fn batch_shuffle_dataloader() { + fn test_simple_dataloader() { let train_set = anki21_sample_file_converted_to_fsrs(); let (_pre_train_set, train_set) = prepare_training_data(train_set); let dataset = FSRSDataset::from(train_set); let batch_size = 512; - let seed = 42; + let seed = 114514; let device = NdArrayDevice::Cpu; type Backend = NdArray; - let batcher = FSRSBatcher::::new(device); - let dataloader = - BatchShuffledDataLoaderBuilder::new(batcher).build(dataset, batch_size, seed); - let item = dataloader.iter().next().unwrap(); + + let dataset = BatchTensorDataset::::new(dataset, batch_size, device); + let dataloader = ShuffleDataLoader::new(dataset, seed); + let mut iterator = dataloader.iter(); + // dbg!(&iterator.indices); + let batch = iterator.next().unwrap(); assert_eq!( - item.t_historys.shape(), - burn::tensor::Shape { dims: [6, 512] } + batch.t_historys.shape(), + Shape { + dims: [7, batch_size] + } ); - let item2 = dataloader.iter().next().unwrap(); + let batch = iterator.next().unwrap(); assert_eq!( - item2.t_historys.shape(), - burn::tensor::Shape { dims: [4, 512] } + batch.t_historys.shape(), + Shape { + dims: [6, batch_size] + } ); - } - #[test] - fn batch_shuffle() { - let dataset = Arc::new(FSRSDataset::from(anki21_sample_file_converted_to_fsrs())); - let batch_size = 10; - let seed = 42; - let batch_shuffled_dataset = BatchShuffledDataset::with_seed(dataset, batch_size, seed); + let lengths = iterator + .map(|batch| batch.t_historys.shape().dims[0]) + .collect::>(); assert_eq!( - (0..batch_shuffled_dataset.len().min(batch_size)) - .map(|i| batch_shuffled_dataset.get(i).unwrap()) - .collect::>(), - [ - FSRSItem { - reviews: vec![ - FSRSReview { - rating: 1, - delta_t: 0 - }, - FSRSReview { - rating: 4, - delta_t: 0 - }, - FSRSReview { - rating: 4, - delta_t: 0 - }, - FSRSReview { - rating: 4, - delta_t: 1 - } - ] - }, - FSRSItem { - reviews: vec![ - FSRSReview { - rating: 4, - delta_t: 0 - }, - FSRSReview { - rating: 4, - delta_t: 0 - }, - FSRSReview { - rating: 4, - delta_t: 1 - }, - FSRSReview { - rating: 3, - delta_t: 2 - } - ] - }, - FSRSItem { - reviews: vec![ - FSRSReview { - rating: 4, - delta_t: 0 - }, - FSRSReview { - rating: 4, - delta_t: 0 - }, - FSRSReview { - rating: 4, - delta_t: 1 - }, - FSRSReview { - rating: 3, - delta_t: 1 - } - ] - }, - FSRSItem { - reviews: vec![ - FSRSReview { - rating: 1, - delta_t: 0 - }, - FSRSReview { - rating: 4, - delta_t: 0 - }, - FSRSReview { - rating: 4, - delta_t: 0 - }, - FSRSReview { - rating: 4, - delta_t: 1 - } - ] - }, - FSRSItem { - reviews: vec![ - FSRSReview { - rating: 4, - delta_t: 0 - }, - FSRSReview { - rating: 4, - delta_t: 0 - }, - FSRSReview { - rating: 4, - delta_t: 1 - }, - FSRSReview { - rating: 3, - delta_t: 1 - } - ] - }, - FSRSItem { - reviews: vec![ - FSRSReview { - rating: 1, - delta_t: 0 - }, - FSRSReview { - rating: 4, - delta_t: 0 - }, - FSRSReview { - rating: 4, - delta_t: 0 - }, - FSRSReview { - rating: 4, - delta_t: 1 - } - ] - }, - FSRSItem { - reviews: vec![ - FSRSReview { - rating: 4, - delta_t: 0 - }, - FSRSReview { - rating: 4, - delta_t: 0 - }, - FSRSReview { - rating: 4, - delta_t: 1 - }, - FSRSReview { - rating: 3, - delta_t: 3 - } - ] - }, - FSRSItem { - reviews: vec![ - FSRSReview { - rating: 4, - delta_t: 0 - }, - FSRSReview { - rating: 4, - delta_t: 0 - }, - FSRSReview { - rating: 4, - delta_t: 1 - }, - FSRSReview { - rating: 3, - delta_t: 1 - } - ] - }, - FSRSItem { - reviews: vec![ - FSRSReview { - rating: 4, - delta_t: 0 - }, - FSRSReview { - rating: 4, - delta_t: 0 - }, - FSRSReview { - rating: 4, - delta_t: 1 - }, - FSRSReview { - rating: 3, - delta_t: 2 - } - ] - }, - FSRSItem { - reviews: vec![ - FSRSReview { - rating: 3, - delta_t: 0 - }, - FSRSReview { - rating: 4, - delta_t: 0 - }, - FSRSReview { - rating: 4, - delta_t: 0 - }, - FSRSReview { - rating: 4, - delta_t: 1 - } - ] - } + lengths, + vec![ + 48, 6, 8, 5, 11, 5, 10, 19, 6, 13, 9, 6, 5, 3, 9, 6, 3, 13, 7, 5, 4, 4, 4, 6, 4, 3, ] ); - } - #[test] - fn item_shuffle() { - use burn::data::dataset::transform::ShuffledDataset; - let dataset = FSRSDataset::from(anki21_sample_file_converted_to_fsrs()); - let seed = 42; - let shuffled_dataset = ShuffledDataset::with_seed(dataset, seed); - for i in 0..shuffled_dataset.len().min(10) { - dbg!(shuffled_dataset.get(i).unwrap()); - } + let mut iterator = dataloader.iter(); + // dbg!(&iterator.indices); + let batch = iterator.next().unwrap(); + assert_eq!( + batch.t_historys.shape(), + Shape { + dims: [19, batch_size] + } + ); + let batch = iterator.next().unwrap(); + assert_eq!( + batch.t_historys.shape(), + Shape { + dims: [9, batch_size] + } + ); + + let lengths = iterator + .map(|batch| batch.t_historys.shape().dims[0]) + .collect::>(); + assert_eq!( + lengths, + vec![3, 11, 3, 6, 6, 6, 5, 5, 7, 6, 4, 9, 10, 4, 48, 3, 4, 5, 13, 13, 7, 5, 4, 8, 6, 6] + ); } } diff --git a/src/dataset.rs b/src/dataset.rs index bc9a8d1..043b143 100644 --- a/src/dataset.rs +++ b/src/dataset.rs @@ -107,15 +107,25 @@ impl Batcher> for FSRSBatcher { delta_t.resize(pad_size, 0); rating.resize(pad_size, 0); let delta_t = Tensor::from_data( - Data::new(delta_t, Shape { dims: [pad_size] }).convert(), + Data::new( + delta_t, + Shape { + dims: [1, pad_size], + }, + ) + .convert(), &self.device, - ) - .unsqueeze(); + ); let rating = Tensor::from_data( - Data::new(rating, Shape { dims: [pad_size] }).convert(), + Data::new( + rating, + Shape { + dims: [1, pad_size], + }, + ) + .convert(), &self.device, - ) - .unsqueeze(); + ); (delta_t, rating) }) .unzip(); @@ -156,7 +166,7 @@ impl Batcher> for FSRSBatcher { } pub(crate) struct FSRSDataset { - items: Vec, + pub(crate) items: Vec, } impl Dataset for FSRSDataset { diff --git a/src/training.rs b/src/training.rs index f938dfc..d64b714 100644 --- a/src/training.rs +++ b/src/training.rs @@ -1,6 +1,6 @@ -use crate::batch_shuffle::BatchShuffledDataLoaderBuilder; +use crate::batch_shuffle::{BatchTensorDataset, ShuffleDataLoader}; use crate::cosine_annealing::CosineAnnealingLR; -use crate::dataset::{prepare_training_data, FSRSBatcher, FSRSDataset, FSRSItem}; +use crate::dataset::{prepare_training_data, FSRSDataset, FSRSItem}; use crate::error::Result; use crate::model::{Model, ModelConfig}; use crate::parameter_clipper::parameter_clipper; @@ -8,7 +8,6 @@ use crate::pre_training::{pretrain, smooth_and_fill}; use crate::{FSRSError, DEFAULT_PARAMETERS, FSRS}; use burn::backend::Autodiff; -use burn::data::dataloader::DataLoaderBuilder; use burn::lr_scheduler::LrScheduler; use burn::module::AutodiffModule; use burn::nn::loss::Reduction; @@ -325,17 +324,19 @@ fn train( // Training data let iterations = (train_set.len() / config.batch_size + 1) * config.num_epochs; - let batcher_train = FSRSBatcher::::new(device.clone()); - let dataloader_train = BatchShuffledDataLoaderBuilder::new(batcher_train).build( + let batch_dataset = BatchTensorDataset::::new( FSRSDataset::from(train_set), config.batch_size, - config.seed, + device.clone(), ); + let dataloader_train = ShuffleDataLoader::new(batch_dataset, config.seed); - let batcher_valid = FSRSBatcher::new(device); - let dataloader_valid = DataLoaderBuilder::new(batcher_valid) - .batch_size(config.batch_size) - .build(FSRSDataset::from(test_set.clone())); + let batch_dataset = BatchTensorDataset::::new( + FSRSDataset::from(test_set.clone()), + config.batch_size, + device, + ); + let dataloader_valid = ShuffleDataLoader::new(batch_dataset, config.seed); let mut lr_scheduler = CosineAnnealingLR::init(iterations as f64, config.learning_rate); let interrupter = TrainingInterrupter::new();