diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index de0fa3c0..164b0bb3 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -93,18 +93,13 @@ def __post_init__(self): @dataclass class NanosetDatasetsArgs: - dataset_folder: Union[str, dict, List[str]] + dataset_folder: Union[str, List[str]] + dataset_weights: Optional[List[float]] = None def __post_init__(self): if isinstance(self.dataset_folder, str): # Case 1: 1 Dataset folder self.dataset_folder = [self.dataset_folder] self.dataset_weights = [1] - elif isinstance(self.dataset_folder, List): # Case 2: > 1 Dataset folder - self.dataset_weights = None # Set to None so we consume all the samples randomly - elif isinstance(self.dataset_folder, dict): # Case 3: dict with > 1 dataset_folder and weights - tmp_dataset_folder = self.dataset_folder.copy() - self.dataset_folder = list(tmp_dataset_folder.keys()) - self.dataset_weights = list(tmp_dataset_folder.values()) @dataclass