From 7323ce1c839062aec4ab7129e43f4efb53c79035 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Celiebak=E2=80=9D?= Date: Thu, 5 Sep 2024 08:48:25 +0000 Subject: [PATCH] make nanoset compatible with python --- src/nanotron/config/config.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index de0fa3c0..164b0bb3 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -93,18 +93,13 @@ def __post_init__(self): @dataclass class NanosetDatasetsArgs: - dataset_folder: Union[str, dict, List[str]] + dataset_folder: Union[str, List[str]] + dataset_weights: Optional[List[float]] = None def __post_init__(self): if isinstance(self.dataset_folder, str): # Case 1: 1 Dataset folder self.dataset_folder = [self.dataset_folder] self.dataset_weights = [1] - elif isinstance(self.dataset_folder, List): # Case 2: > 1 Dataset folder - self.dataset_weights = None # Set to None so we consume all the samples randomly - elif isinstance(self.dataset_folder, dict): # Case 3: dict with > 1 dataset_folder and weights - tmp_dataset_folder = self.dataset_folder.copy() - self.dataset_folder = list(tmp_dataset_folder.keys()) - self.dataset_weights = list(tmp_dataset_folder.values()) @dataclass