From 3594d44b0f07dadaac0cfb218d69d493545dcd75 Mon Sep 17 00:00:00 2001 From: Adam Narozniak <51029327+adam-narozniak@users.noreply.github.com> Date: Fri, 20 Dec 2024 12:24:32 +0100 Subject: [PATCH] refactor(datasets) Update type hinting generics based on typing module (#4750) --- datasets/flwr_datasets/federated_dataset.py | 10 ++--- .../flwr_datasets/preprocessor/divider.py | 40 +++++++++---------- 2 files changed, 25 insertions(+), 25 deletions(-) diff --git a/datasets/flwr_datasets/federated_dataset.py b/datasets/flwr_datasets/federated_dataset.py index 8659aa03313b..9d2fdd7319c8 100644 --- a/datasets/flwr_datasets/federated_dataset.py +++ b/datasets/flwr_datasets/federated_dataset.py @@ -15,7 +15,7 @@ """FederatedDataset.""" -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Optional, Union import datasets from datasets import Dataset, DatasetDict @@ -113,8 +113,8 @@ def __init__( *, dataset: str, subset: Optional[str] = None, - preprocessor: Optional[Union[Preprocessor, Dict[str, Tuple[str, ...]]]] = None, - partitioners: Dict[str, Union[Partitioner, int]], + preprocessor: Optional[Union[Preprocessor, dict[str, tuple[str, ...]]]] = None, + partitioners: dict[str, Union[Partitioner, int]], shuffle: bool = True, seed: Optional[int] = 42, **load_dataset_kwargs: Any, @@ -125,7 +125,7 @@ def __init__( self._preprocessor: Optional[Preprocessor] = _instantiate_merger_if_needed( preprocessor ) - self._partitioners: Dict[str, Partitioner] = _instantiate_partitioners( + self._partitioners: dict[str, Partitioner] = _instantiate_partitioners( partitioners ) self._check_partitioners_correctness() @@ -241,7 +241,7 @@ def load_split(self, split: str) -> Dataset: return dataset_split @property - def partitioners(self) -> Dict[str, Partitioner]: + def partitioners(self) -> dict[str, Partitioner]: """Dictionary mapping each split to its associated partitioner. The returned partitioners have the splits of the dataset assigned to them. diff --git a/datasets/flwr_datasets/preprocessor/divider.py b/datasets/flwr_datasets/preprocessor/divider.py index 9d7570de4cea..3650a95a5a98 100644 --- a/datasets/flwr_datasets/preprocessor/divider.py +++ b/datasets/flwr_datasets/preprocessor/divider.py @@ -17,7 +17,7 @@ import collections import warnings -from typing import Dict, List, Optional, Union, cast +from typing import Optional, Union, cast import datasets from datasets import DatasetDict @@ -101,28 +101,28 @@ class Divider: def __init__( self, divide_config: Union[ - Dict[str, float], - Dict[str, int], - Dict[str, Dict[str, float]], - Dict[str, Dict[str, int]], + dict[str, float], + dict[str, int], + dict[str, dict[str, float]], + dict[str, dict[str, int]], ], divide_split: Optional[str] = None, drop_remaining_splits: bool = False, ) -> None: - self._single_split_config: Union[Dict[str, float], Dict[str, int]] + self._single_split_config: Union[dict[str, float], dict[str, int]] self._multiple_splits_config: Union[ - Dict[str, Dict[str, float]], Dict[str, Dict[str, int]] + dict[str, dict[str, float]], dict[str, dict[str, int]] ] self._config_type = _determine_config_type(divide_config) self._check_type_correctness(divide_config) if self._config_type == "single-split": self._single_split_config = cast( - Union[Dict[str, float], Dict[str, int]], divide_config + Union[dict[str, float], dict[str, int]], divide_config ) else: self._multiple_splits_config = cast( - Union[Dict[str, Dict[str, float]], Dict[str, Dict[str, int]]], + Union[dict[str, dict[str, float]], dict[str, dict[str, int]]], divide_config, ) self._divide_split = divide_split @@ -141,7 +141,7 @@ def __call__(self, dataset: DatasetDict) -> DatasetDict: def resplit(self, dataset: DatasetDict) -> DatasetDict: """Resplit the dataset according to the configuration.""" resplit_dataset = {} - dataset_splits: List[str] = list(dataset.keys()) + dataset_splits: list[str] = list(dataset.keys()) # Change the "single-split" config to look like "multiple-split" config if self._config_type == "single-split": # First, if the `divide_split` is None determine the split @@ -154,7 +154,7 @@ def resplit(self, dataset: DatasetDict) -> DatasetDict: ) self._divide_split = dataset_splits[0] self._multiple_splits_config = cast( - Union[Dict[str, Dict[str, float]], Dict[str, Dict[str, int]]], + Union[dict[str, dict[str, float]], dict[str, dict[str, int]]], {self._divide_split: self._single_split_config}, ) @@ -226,7 +226,7 @@ def _check_duplicate_splits_in_config(self) -> None: ) def _check_duplicate_splits_in_config_and_original_dataset( - self, dataset_splits: List[str] + self, dataset_splits: list[str] ) -> None: """Check duplicates along the new split values and dataset splits. @@ -303,10 +303,10 @@ def _warn_on_potential_misuse_of_divide_split(self) -> None: def _check_type_correctness( self, divide_config: Union[ - Dict[str, float], - Dict[str, int], - Dict[str, Dict[str, float]], - Dict[str, Dict[str, int]], + dict[str, float], + dict[str, int], + dict[str, dict[str, float]], + dict[str, dict[str, int]], ], ) -> None: assert self._config_type in [ @@ -356,10 +356,10 @@ def _check_type_correctness( def _determine_config_type( config: Union[ - Dict[str, float], - Dict[str, int], - Dict[str, Dict[str, float]], - Dict[str, Dict[str, int]], + dict[str, float], + dict[str, int], + dict[str, dict[str, float]], + dict[str, dict[str, int]], ], ) -> str: """Determine configuration type of `divide_config` based on the dict structure.