diff --git a/datasets/flwr_datasets/__init__.py b/datasets/flwr_datasets/__init__.py index 89e651e0f558..48d993037708 100644 --- a/datasets/flwr_datasets/__init__.py +++ b/datasets/flwr_datasets/__init__.py @@ -15,10 +15,14 @@ """Flower Datasets main package.""" +from flwr_datasets import partitioner, resplitter from flwr_datasets.common.version import package_version as _package_version +from flwr_datasets.federated_dataset import FederatedDataset -from .federated_dataset import FederatedDataset - -__all__ = ["FederatedDataset"] +__all__ = [ + "FederatedDataset", + "partitioner", + "resplitter", +] __version__ = _package_version diff --git a/datasets/flwr_datasets/resplitter/__init__.py b/datasets/flwr_datasets/resplitter/__init__.py new file mode 100644 index 000000000000..f778d2096b76 --- /dev/null +++ b/datasets/flwr_datasets/resplitter/__init__.py @@ -0,0 +1,22 @@ +# Copyright 2023 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Resplitter package.""" + + +from .merge_resplitter import MergeResplitter + +__all__ = [ + "MergeResplitter", +] diff --git a/datasets/flwr_datasets/merge_resplitter.py b/datasets/flwr_datasets/resplitter/merge_resplitter.py similarity index 100% rename from datasets/flwr_datasets/merge_resplitter.py rename to datasets/flwr_datasets/resplitter/merge_resplitter.py diff --git a/datasets/flwr_datasets/merge_resplitter_test.py b/datasets/flwr_datasets/resplitter/merge_resplitter_test.py similarity index 98% rename from datasets/flwr_datasets/merge_resplitter_test.py rename to datasets/flwr_datasets/resplitter/merge_resplitter_test.py index 096bd7efac27..ebbdfb4022b0 100644 --- a/datasets/flwr_datasets/merge_resplitter_test.py +++ b/datasets/flwr_datasets/resplitter/merge_resplitter_test.py @@ -21,7 +21,7 @@ import pytest from datasets import Dataset, DatasetDict -from flwr_datasets.merge_resplitter import MergeResplitter +from flwr_datasets.resplitter.merge_resplitter import MergeResplitter class TestResplitter(unittest.TestCase): diff --git a/datasets/flwr_datasets/utils.py b/datasets/flwr_datasets/utils.py index 7badb1445460..49c65e9893a7 100644 --- a/datasets/flwr_datasets/utils.py +++ b/datasets/flwr_datasets/utils.py @@ -19,8 +19,8 @@ from typing import Dict, Optional, Tuple, Union, cast from flwr_datasets.common import Resplitter -from flwr_datasets.merge_resplitter import MergeResplitter from flwr_datasets.partitioner import IidPartitioner, Partitioner +from flwr_datasets.resplitter.merge_resplitter import MergeResplitter tested_datasets = [ "mnist",