Skip to content

Commit

Permalink
refactor(datasets) Update type hinting generics based on typing module (
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-narozniak authored Dec 20, 2024
1 parent fddcd17 commit 3594d44
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 25 deletions.
10 changes: 5 additions & 5 deletions datasets/flwr_datasets/federated_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""FederatedDataset."""


from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Optional, Union

import datasets
from datasets import Dataset, DatasetDict
Expand Down Expand Up @@ -113,8 +113,8 @@ def __init__(
*,
dataset: str,
subset: Optional[str] = None,
preprocessor: Optional[Union[Preprocessor, Dict[str, Tuple[str, ...]]]] = None,
partitioners: Dict[str, Union[Partitioner, int]],
preprocessor: Optional[Union[Preprocessor, dict[str, tuple[str, ...]]]] = None,
partitioners: dict[str, Union[Partitioner, int]],
shuffle: bool = True,
seed: Optional[int] = 42,
**load_dataset_kwargs: Any,
Expand All @@ -125,7 +125,7 @@ def __init__(
self._preprocessor: Optional[Preprocessor] = _instantiate_merger_if_needed(
preprocessor
)
self._partitioners: Dict[str, Partitioner] = _instantiate_partitioners(
self._partitioners: dict[str, Partitioner] = _instantiate_partitioners(
partitioners
)
self._check_partitioners_correctness()
Expand Down Expand Up @@ -241,7 +241,7 @@ def load_split(self, split: str) -> Dataset:
return dataset_split

@property
def partitioners(self) -> Dict[str, Partitioner]:
def partitioners(self) -> dict[str, Partitioner]:
"""Dictionary mapping each split to its associated partitioner.
The returned partitioners have the splits of the dataset assigned to them.
Expand Down
40 changes: 20 additions & 20 deletions datasets/flwr_datasets/preprocessor/divider.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import collections
import warnings
from typing import Dict, List, Optional, Union, cast
from typing import Optional, Union, cast

import datasets
from datasets import DatasetDict
Expand Down Expand Up @@ -101,28 +101,28 @@ class Divider:
def __init__(
self,
divide_config: Union[
Dict[str, float],
Dict[str, int],
Dict[str, Dict[str, float]],
Dict[str, Dict[str, int]],
dict[str, float],
dict[str, int],
dict[str, dict[str, float]],
dict[str, dict[str, int]],
],
divide_split: Optional[str] = None,
drop_remaining_splits: bool = False,
) -> None:
self._single_split_config: Union[Dict[str, float], Dict[str, int]]
self._single_split_config: Union[dict[str, float], dict[str, int]]
self._multiple_splits_config: Union[
Dict[str, Dict[str, float]], Dict[str, Dict[str, int]]
dict[str, dict[str, float]], dict[str, dict[str, int]]
]

self._config_type = _determine_config_type(divide_config)
self._check_type_correctness(divide_config)
if self._config_type == "single-split":
self._single_split_config = cast(
Union[Dict[str, float], Dict[str, int]], divide_config
Union[dict[str, float], dict[str, int]], divide_config
)
else:
self._multiple_splits_config = cast(
Union[Dict[str, Dict[str, float]], Dict[str, Dict[str, int]]],
Union[dict[str, dict[str, float]], dict[str, dict[str, int]]],
divide_config,
)
self._divide_split = divide_split
Expand All @@ -141,7 +141,7 @@ def __call__(self, dataset: DatasetDict) -> DatasetDict:
def resplit(self, dataset: DatasetDict) -> DatasetDict:
"""Resplit the dataset according to the configuration."""
resplit_dataset = {}
dataset_splits: List[str] = list(dataset.keys())
dataset_splits: list[str] = list(dataset.keys())
# Change the "single-split" config to look like "multiple-split" config
if self._config_type == "single-split":
# First, if the `divide_split` is None determine the split
Expand All @@ -154,7 +154,7 @@ def resplit(self, dataset: DatasetDict) -> DatasetDict:
)
self._divide_split = dataset_splits[0]
self._multiple_splits_config = cast(
Union[Dict[str, Dict[str, float]], Dict[str, Dict[str, int]]],
Union[dict[str, dict[str, float]], dict[str, dict[str, int]]],
{self._divide_split: self._single_split_config},
)

Expand Down Expand Up @@ -226,7 +226,7 @@ def _check_duplicate_splits_in_config(self) -> None:
)

def _check_duplicate_splits_in_config_and_original_dataset(
self, dataset_splits: List[str]
self, dataset_splits: list[str]
) -> None:
"""Check duplicates along the new split values and dataset splits.
Expand Down Expand Up @@ -303,10 +303,10 @@ def _warn_on_potential_misuse_of_divide_split(self) -> None:
def _check_type_correctness(
self,
divide_config: Union[
Dict[str, float],
Dict[str, int],
Dict[str, Dict[str, float]],
Dict[str, Dict[str, int]],
dict[str, float],
dict[str, int],
dict[str, dict[str, float]],
dict[str, dict[str, int]],
],
) -> None:
assert self._config_type in [
Expand Down Expand Up @@ -356,10 +356,10 @@ def _check_type_correctness(

def _determine_config_type(
config: Union[
Dict[str, float],
Dict[str, int],
Dict[str, Dict[str, float]],
Dict[str, Dict[str, int]],
dict[str, float],
dict[str, int],
dict[str, dict[str, float]],
dict[str, dict[str, int]],
],
) -> str:
"""Determine configuration type of `divide_config` based on the dict structure.
Expand Down

0 comments on commit 3594d44

Please sign in to comment.