Skip to content

Commit

Permalink
Add DataloaderConfig to Trainer
Browse files Browse the repository at this point in the history
  • Loading branch information
limiteinductive committed Apr 15, 2024
1 parent b9b999c commit be7d065
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 1 deletion.
12 changes: 12 additions & 0 deletions src/refiners/training_utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,17 @@ def get(self, params: ParamsT) -> Optimizer:
)


class DataloaderConfig(BaseModel):
    """Configuration for the DataLoader used by the Trainer.

    Field names mirror keyword arguments of ``torch.utils.data.DataLoader``;
    each value is forwarded verbatim when the loader is built.
    """

    # Reject unknown keys so a typo in a config file fails loudly at load time.
    model_config = ConfigDict(extra="forbid")

    num_workers: int = 0  # worker subprocesses for loading; 0 = load in the main process
    pin_memory: bool = False  # copy tensors into pinned (page-locked) memory before returning them
    prefetch_factor: int | None = None  # batches pre-loaded per worker; presumably only valid with num_workers > 0 — TODO confirm
    persistent_workers: bool = False  # keep worker processes alive between epochs
    drop_last: bool = False  # drop the final incomplete batch instead of yielding it
    shuffle: bool = True  # reshuffle the dataset every epoch (on by default, unlike DataLoader itself)


class ModelConfig(BaseModel):
# If None, then requires_grad will NOT be changed when loading the model
# this can be useful if you want to train only a part of the model
Expand All @@ -167,6 +178,7 @@ class BaseConfig(BaseModel):
optimizer: OptimizerConfig
lr_scheduler: LRSchedulerConfig
clock: ClockConfig = ClockConfig()
dataloader: DataloaderConfig = DataloaderConfig()

model_config = ConfigDict(extra="forbid")

Expand Down
11 changes: 10 additions & 1 deletion src/refiners/training_utils/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,8 +329,17 @@ def dataset(self) -> Dataset[Batch]:

@cached_property
def dataloader(self) -> DataLoader[Any]:
    """Build the training DataLoader from ``self.config.dataloader``.

    ``batch_size`` intentionally comes from the ``training`` config section,
    not from the dataloader section; all other knobs are forwarded verbatim
    from :class:`DataloaderConfig`. Cached so the same loader (and its
    worker pool) is reused across epochs.
    """
    config = self.config.dataloader
    return DataLoader(
        dataset=self.dataset,
        batch_size=self.config.training.batch_size,
        collate_fn=self.collate_fn,
        num_workers=config.num_workers,
        prefetch_factor=config.prefetch_factor,
        persistent_workers=config.persistent_workers,
        pin_memory=config.pin_memory,
        shuffle=config.shuffle,
        drop_last=config.drop_last,
    )

@abstractmethod
Expand Down

0 comments on commit be7d065

Please sign in to comment.