From 1f9fa755d005f1df4aa01366ca482385b9520de8 Mon Sep 17 00:00:00 2001 From: "Daniel J. Beutel" Date: Thu, 23 Nov 2023 09:15:15 +0100 Subject: [PATCH 1/2] Decouple client callable (#2393) --- examples/mt-pytorch-callable/README.md | 49 +++++++ examples/mt-pytorch-callable/client.py | 123 ++++++++++++++++ examples/mt-pytorch-callable/driver.py | 25 ++++ examples/mt-pytorch-callable/pyproject.toml | 16 ++ examples/mt-pytorch-callable/requirements.txt | 4 + examples/mt-pytorch-callable/run.sh | 20 +++ examples/mt-pytorch-callable/server.py | 25 ++++ src/py/flwr/__init__.py | 3 +- src/py/flwr/client/app.py | 84 ++++++++--- src/py/flwr/client/flower.py | 138 ++++++++++++++++++ src/py/flwr/flower/__init__.py | 26 ++++ 11 files changed, 495 insertions(+), 18 deletions(-) create mode 100644 examples/mt-pytorch-callable/README.md create mode 100644 examples/mt-pytorch-callable/client.py create mode 100644 examples/mt-pytorch-callable/driver.py create mode 100644 examples/mt-pytorch-callable/pyproject.toml create mode 100644 examples/mt-pytorch-callable/requirements.txt create mode 100755 examples/mt-pytorch-callable/run.sh create mode 100644 examples/mt-pytorch-callable/server.py create mode 100644 src/py/flwr/client/flower.py create mode 100644 src/py/flwr/flower/__init__.py diff --git a/examples/mt-pytorch-callable/README.md b/examples/mt-pytorch-callable/README.md new file mode 100644 index 000000000000..65ef000c26f2 --- /dev/null +++ b/examples/mt-pytorch-callable/README.md @@ -0,0 +1,49 @@ +# Deploy ๐Ÿงช + +๐Ÿงช = this page covers experimental features that might change in future versions of Flower + +This how-to guide describes the deployment of a long-running Flower server. + +## Preconditions + +Let's assume the following project structure: + +```bash +$ tree . +. +โ””โ”€โ”€ client.py +โ”œโ”€โ”€ driver.py +โ”œโ”€โ”€ requirements.txt +``` + +## Install dependencies + +```bash +pip install -r requirements.txt +``` + +## Start the long-running Flower server + +```bash +flower-server --insecure +``` + +## Start the long-running Flower client + +In a new terminal window, start the first long-running Flower client: + +```bash +flower-client --callable client:flower +``` + +In yet another new terminal window, start the second long-running Flower client: + +```bash +flower-client --callable client:flower +``` + +## Start the Driver script + +```bash +python driver.py +``` diff --git a/examples/mt-pytorch-callable/client.py b/examples/mt-pytorch-callable/client.py new file mode 100644 index 000000000000..6f9747784ae0 --- /dev/null +++ b/examples/mt-pytorch-callable/client.py @@ -0,0 +1,123 @@ +import warnings +from collections import OrderedDict + +import flwr as fl +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import DataLoader +from torchvision.datasets import CIFAR10 +from torchvision.transforms import Compose, Normalize, ToTensor +from tqdm import tqdm + + +# ############################################################################# +# 1. Regular PyTorch pipeline: nn.Module, train, test, and DataLoader +# ############################################################################# + +warnings.filterwarnings("ignore", category=UserWarning) +DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + +class Net(nn.Module): + """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')""" + + def __init__(self) -> None: + super(Net, self).__init__() + self.conv1 = nn.Conv2d(3, 6, 5) + self.pool = nn.MaxPool2d(2, 2) + self.conv2 = nn.Conv2d(6, 16, 5) + self.fc1 = nn.Linear(16 * 5 * 5, 120) + self.fc2 = nn.Linear(120, 84) + self.fc3 = nn.Linear(84, 10) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.pool(F.relu(self.conv1(x))) + x = self.pool(F.relu(self.conv2(x))) + x = x.view(-1, 16 * 5 * 5) + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + return self.fc3(x) + + +def train(net, trainloader, epochs): + """Train the model on the training set.""" + criterion = torch.nn.CrossEntropyLoss() + optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9) + for _ in range(epochs): + for images, labels in tqdm(trainloader): + optimizer.zero_grad() + criterion(net(images.to(DEVICE)), labels.to(DEVICE)).backward() + optimizer.step() + + +def test(net, testloader): + """Validate the model on the test set.""" + criterion = torch.nn.CrossEntropyLoss() + correct, loss = 0, 0.0 + with torch.no_grad(): + for images, labels in tqdm(testloader): + outputs = net(images.to(DEVICE)) + labels = labels.to(DEVICE) + loss += criterion(outputs, labels).item() + correct += (torch.max(outputs.data, 1)[1] == labels).sum().item() + accuracy = correct / len(testloader.dataset) + return loss, accuracy + + +def load_data(): + """Load CIFAR-10 (training and test set).""" + trf = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) + trainset = CIFAR10("./data", train=True, download=True, transform=trf) + testset = CIFAR10("./data", train=False, download=True, transform=trf) + return DataLoader(trainset, batch_size=32, shuffle=True), DataLoader(testset) + + +# ############################################################################# +# 2. Federation of the pipeline with Flower +# ############################################################################# + +# Load model and data (simple CNN, CIFAR-10) +net = Net().to(DEVICE) +trainloader, testloader = load_data() + + +# Define Flower client +class FlowerClient(fl.client.NumPyClient): + def get_parameters(self, config): + return [val.cpu().numpy() for _, val in net.state_dict().items()] + + def set_parameters(self, parameters): + params_dict = zip(net.state_dict().keys(), parameters) + state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) + net.load_state_dict(state_dict, strict=True) + + def fit(self, parameters, config): + self.set_parameters(parameters) + train(net, trainloader, epochs=1) + return self.get_parameters(config={}), len(trainloader.dataset), {} + + def evaluate(self, parameters, config): + self.set_parameters(parameters) + loss, accuracy = test(net, testloader) + return loss, len(testloader.dataset), {"accuracy": accuracy} + + +def client_fn(cid: str): + """.""" + return FlowerClient().to_client() + + +# To run this: `flower-client --callable client:flower` +flower = fl.flower.Flower( + client_fn=client_fn, +) + + +if __name__ == "__main__": + # Start Flower client + fl.client.start_client( + server_address="0.0.0.0:9092", + client=FlowerClient().to_client(), + transport="grpc-rere", + ) diff --git a/examples/mt-pytorch-callable/driver.py b/examples/mt-pytorch-callable/driver.py new file mode 100644 index 000000000000..1248672b6813 --- /dev/null +++ b/examples/mt-pytorch-callable/driver.py @@ -0,0 +1,25 @@ +from typing import List, Tuple + +import flwr as fl +from flwr.common import Metrics + + +# Define metric aggregation function +def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: + # Multiply accuracy of each client by number of examples used + accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics] + examples = [num_examples for num_examples, _ in metrics] + + # Aggregate and return custom metric (weighted average) + return {"accuracy": sum(accuracies) / sum(examples)} + + +# Define strategy +strategy = fl.server.strategy.FedAvg(evaluate_metrics_aggregation_fn=weighted_average) + +# Start Flower driver +fl.driver.start_driver( + server_address="0.0.0.0:9091", + config=fl.server.ServerConfig(num_rounds=3), + strategy=strategy, +) diff --git a/examples/mt-pytorch-callable/pyproject.toml b/examples/mt-pytorch-callable/pyproject.toml new file mode 100644 index 000000000000..0d1a91836006 --- /dev/null +++ b/examples/mt-pytorch-callable/pyproject.toml @@ -0,0 +1,16 @@ +[build-system] +requires = ["poetry-core>=1.4.0"] +build-backend = "poetry.core.masonry.api" + +[tool.poetry] +name = "quickstart-pytorch" +version = "0.1.0" +description = "PyTorch Federated Learning Quickstart with Flower" +authors = ["The Flower Authors "] + +[tool.poetry.dependencies] +python = ">=3.8,<3.11" +flwr = { path = "../../", develop = true, extras = ["simulation", "rest"] } +torch = "1.13.1" +torchvision = "0.14.1" +tqdm = "4.65.0" diff --git a/examples/mt-pytorch-callable/requirements.txt b/examples/mt-pytorch-callable/requirements.txt new file mode 100644 index 000000000000..797ca6db6244 --- /dev/null +++ b/examples/mt-pytorch-callable/requirements.txt @@ -0,0 +1,4 @@ +flwr>=1.0, <2.0 +torch==1.13.1 +torchvision==0.14.1 +tqdm==4.65.0 diff --git a/examples/mt-pytorch-callable/run.sh b/examples/mt-pytorch-callable/run.sh new file mode 100755 index 000000000000..d2bf34f834b1 --- /dev/null +++ b/examples/mt-pytorch-callable/run.sh @@ -0,0 +1,20 @@ +#!/bin/bash +set -e +cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"/ + +# Download the CIFAR-10 dataset +python -c "from torchvision.datasets import CIFAR10; CIFAR10('./data', download=True)" + +echo "Starting server" +python server.py & +sleep 3 # Sleep for 3s to give the server enough time to start + +for i in `seq 0 1`; do + echo "Starting client $i" + python client.py & +done + +# Enable CTRL+C to stop all background processes +trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM +# Wait for all background processes to complete +wait diff --git a/examples/mt-pytorch-callable/server.py b/examples/mt-pytorch-callable/server.py new file mode 100644 index 000000000000..fe691a88aba0 --- /dev/null +++ b/examples/mt-pytorch-callable/server.py @@ -0,0 +1,25 @@ +from typing import List, Tuple + +import flwr as fl +from flwr.common import Metrics + + +# Define metric aggregation function +def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: + # Multiply accuracy of each client by number of examples used + accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics] + examples = [num_examples for num_examples, _ in metrics] + + # Aggregate and return custom metric (weighted average) + return {"accuracy": sum(accuracies) / sum(examples)} + + +# Define strategy +strategy = fl.server.strategy.FedAvg(evaluate_metrics_aggregation_fn=weighted_average) + +# Start Flower server +fl.server.start_server( + server_address="0.0.0.0:8080", + config=fl.server.ServerConfig(num_rounds=3), + strategy=strategy, +) diff --git a/src/py/flwr/__init__.py b/src/py/flwr/__init__.py index d3cbf00747a4..e05799280339 100644 --- a/src/py/flwr/__init__.py +++ b/src/py/flwr/__init__.py @@ -17,12 +17,13 @@ from flwr.common.version import package_version as _package_version -from . import client, common, driver, server, simulation +from . import client, common, driver, flower, server, simulation __all__ = [ "client", "common", "driver", + "flower", "server", "simulation", ] diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index 0013b74c631c..b39dbbfc33c0 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -22,6 +22,7 @@ from typing import Callable, ContextManager, Optional, Tuple, Union from flwr.client.client import Client +from flwr.client.flower import Bwd, Flower, Fwd from flwr.client.typing import ClientFn from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, event from flwr.common.address import parse_address @@ -32,13 +33,15 @@ TRANSPORT_TYPE_REST, TRANSPORT_TYPES, ) -from flwr.common.logger import log +from flwr.common.logger import log, warn_experimental_feature from flwr.proto.task_pb2 import TaskIns, TaskRes +from .flower import load_callable from .grpc_client.connection import grpc_connection from .grpc_rere_client.connection import grpc_request_response -from .message_handler.message_handler import handle, handle_control_message +from .message_handler.message_handler import handle_control_message from .numpy_client import NumPyClient +from .workload_state import WorkloadState def run_client() -> None: @@ -48,6 +51,22 @@ def run_client() -> None: args = _parse_args_client().parse_args() print(args.server) + print(args.callable_dir) + print(args.callable) + + callable_dir = args.callable_dir + if callable_dir is not None: + sys.path.insert(0, callable_dir) + + def _load() -> Flower: + flower: Flower = load_callable(args.callable) + return flower + + return start_client( + server_address=args.server, + load_callable_fn=_load, + transport="grpc-rere", # Only + ) def _parse_args_client() -> argparse.ArgumentParser: @@ -58,8 +77,18 @@ def _parse_args_client() -> argparse.ArgumentParser: parser.add_argument( "--server", - help="Server address", default="0.0.0.0:9092", + help="Server address", + ) + parser.add_argument( + "--callable", + help="For example: `client:flower` or `project.package.module:wrapper.flower`", + ) + parser.add_argument( + "--callable-dir", + default="", + help="Add specified directory to the PYTHONPATH and load callable from there." + " Default: current working directory.", ) return parser @@ -84,6 +113,7 @@ def _check_actionable_client( def start_client( *, server_address: str, + load_callable_fn: Optional[Callable[[], Flower]] = None, client_fn: Optional[ClientFn] = None, client: Optional[Client] = None, grpc_max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, @@ -98,6 +128,8 @@ def start_client( The IPv4 or IPv6 address of the server. If the Flower server runs on the same machine on port 8080, then `server_address` would be `"[::]:8080"`. + load_callable_fn : Optional[Callable[[], Flower]] (default: None) + ... client_fn : Optional[ClientFn] A callable that instantiates a Client. (default: None) client : Optional[flwr.client.Client] @@ -146,20 +178,31 @@ class `flwr.client.Client` (default: None) """ event(EventType.START_CLIENT_ENTER) - _check_actionable_client(client, client_fn) + if load_callable_fn is None: + _check_actionable_client(client, client_fn) - if client_fn is None: - # Wrap `Client` instance in `client_fn` - def single_client_factory( - cid: str, # pylint: disable=unused-argument - ) -> Client: - if client is None: # Added this to keep mypy happy - raise Exception( - "Both `client_fn` and `client` are `None`, but one is required" - ) - return client # Always return the same instance + if client_fn is None: + # Wrap `Client` instance in `client_fn` + def single_client_factory( + cid: str, # pylint: disable=unused-argument + ) -> Client: + if client is None: # Added this to keep mypy happy + raise Exception( + "Both `client_fn` and `client` are `None`, but one is required" + ) + return client # Always return the same instance + + client_fn = single_client_factory + + def _load_app() -> Flower: + return Flower(client_fn=client_fn) - client_fn = single_client_factory + load_callable_fn = _load_app + else: + warn_experimental_feature("`load_callable_fn`") + + # At this point, only `load_callable_fn` should be used + # Both `client` and `client_fn` must not be used directly # Initialize connection context manager connection, address = _init_connection(transport, server_address) @@ -190,11 +233,18 @@ def single_client_factory( send(task_res) break + # Load app + app: Flower = load_callable_fn() + # Handle task message - task_res = handle(client_fn, task_ins) + fwd_msg: Fwd = Fwd( + task_ins=task_ins, + state=WorkloadState(state={}), + ) + bwd_msg: Bwd = app(fwd=fwd_msg) # Send - send(task_res) + send(bwd_msg.task_res) # Unregister node if delete_node is not None: diff --git a/src/py/flwr/client/flower.py b/src/py/flwr/client/flower.py new file mode 100644 index 000000000000..9eeb41887e24 --- /dev/null +++ b/src/py/flwr/client/flower.py @@ -0,0 +1,138 @@ +# Copyright 2023 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Flower callable.""" + + +import importlib +from dataclasses import dataclass +from typing import Callable, cast + +from flwr.client.message_handler.message_handler import handle +from flwr.client.typing import ClientFn +from flwr.client.workload_state import WorkloadState +from flwr.proto.task_pb2 import TaskIns, TaskRes + + +@dataclass +class Fwd: + """.""" + + task_ins: TaskIns + state: WorkloadState + + +@dataclass +class Bwd: + """.""" + + task_res: TaskRes + state: WorkloadState + + +FlowerCallable = Callable[[Fwd], Bwd] + + +class Flower: + """Flower callable. + + Examples + -------- + Assuming a typical client implementation in `FlowerClient`, you can wrap it in a + Flower callable as follows: + + >>> class FlowerClient(NumPyClient): + >>> # ... + >>> + >>> def client_fn(cid): + >>> return FlowerClient().to_client() + >>> + >>> flower = Flower(client_fn) + + If the above code is in a Python module called `client`, it can be started as + follows: + + >>> flower-client --callable client:flower + + In this `client:flower` example, `client` refers to the Python module in which the + previous code lives in. `flower` refers to the global attribute `flower` that points + to an object of type `Flower` (a Flower callable). + """ + + def __init__( + self, + client_fn: ClientFn, # Only for backward compatibility + ) -> None: + self.client_fn = client_fn + + def __call__(self, fwd: Fwd) -> Bwd: + """.""" + # Execute the task + task_res = handle( + client_fn=self.client_fn, + task_ins=fwd.task_ins, + ) + return Bwd( + task_res=task_res, + state=WorkloadState(state={}), + ) + + +class LoadCallableError(Exception): + """.""" + + +def load_callable(module_attribute_str: str) -> Flower: + """Load the `Flower` object specified in a module attribute string. + + The module/attribute string should have the form :. Valid + examples include `client:flower` and `project.package.module:wrapper.flower`. It + must refer to a module on the PYTHONPATH, the module needs to have the specified + attribute, and the attribute must be of type `Flower`. + """ + module_str, _, attributes_str = module_attribute_str.partition(":") + if not module_str: + raise LoadCallableError( + f"Missing module in {module_attribute_str}", + ) from None + if not attributes_str: + raise LoadCallableError( + f"Missing attribute in {module_attribute_str}", + ) from None + + # Load module + try: + module = importlib.import_module(module_str) + except ModuleNotFoundError: + raise LoadCallableError( + f"Unable to load module {module_str}", + ) from None + + # Recursively load attribute + attribute = module + try: + for attribute_str in attributes_str.split("."): + attribute = getattr(attribute, attribute_str) + except AttributeError: + raise LoadCallableError( + f"Unable to load attribute {attributes_str} from module {module_str}", + ) from None + + # Check type + if not isinstance(attribute, Flower): + raise LoadCallableError( + f"Attribute {attributes_str} is not of type {Flower}", + ) from None + + return cast(Flower, attribute) diff --git a/src/py/flwr/flower/__init__.py b/src/py/flwr/flower/__init__.py new file mode 100644 index 000000000000..090c78062d02 --- /dev/null +++ b/src/py/flwr/flower/__init__.py @@ -0,0 +1,26 @@ +# Copyright 2020 Adap GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Flower callable package.""" + + +from flwr.client.flower import Bwd as Bwd +from flwr.client.flower import Flower as Flower +from flwr.client.flower import Fwd as Fwd + +__all__ = [ + "Flower", + "Fwd", + "Bwd", +] From 18c838eac2496708dcc45d6865ec2c0721bb7f24 Mon Sep 17 00:00:00 2001 From: Adam Narozniak <51029327+adam-narozniak@users.noreply.github.com> Date: Thu, 23 Nov 2023 11:08:46 +0100 Subject: [PATCH 2/2] Automate FDS reference doc generation (#2562) --- datasets/dev/build-flwr-datasets-docs.sh | 30 +++++++++ .../source/_templates/autosummary/class.rst | 33 ++++++++++ .../source/_templates/autosummary/module.rst | 66 +++++++++++++++++++ datasets/doc/source/conf.py | 34 ++++++++++ datasets/doc/source/index.rst | 10 ++- datasets/doc/source/ref-api-flwr-datasets.rst | 27 -------- dev/build-docs.sh | 3 +- 7 files changed, 171 insertions(+), 32 deletions(-) create mode 100755 datasets/dev/build-flwr-datasets-docs.sh create mode 100644 datasets/doc/source/_templates/autosummary/class.rst create mode 100644 datasets/doc/source/_templates/autosummary/module.rst delete mode 100644 datasets/doc/source/ref-api-flwr-datasets.rst diff --git a/datasets/dev/build-flwr-datasets-docs.sh b/datasets/dev/build-flwr-datasets-docs.sh new file mode 100755 index 000000000000..dc3cd979d5c8 --- /dev/null +++ b/datasets/dev/build-flwr-datasets-docs.sh @@ -0,0 +1,30 @@ +#!/bin/bash +# Generating the docs, rename and move the files such that the meet the convention used in Flower. +# Note that it involves two runs of sphinx-build that are necessary. +# The first run generates the .rst files (and the html files that are discarded) +# The second time it is run after the files are renamed and moved to the correct place. It generates the final htmls. + +set -e + +cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"/../doc + +# Remove the old docs from source/ref-api +REF_API_DIR="source/ref-api" +if [[ -d "$REF_API_DIR" ]]; then + + echo "Removing ${REF_API_DIR}" + rm -r ${REF_API_DIR} +fi + +# Remove the old html files +if [[ -d build ]]; then + echo "Removing ./build" + rm -r build +fi + +# Docs generation: Generate new rst files +# It starts at the __init__ in the main directory and recursively generated the documentation for the +# specified classes/modules/packages specified in __all__. +# Note if a package cannot be reach via the recursive traversal, even if it has __all__, it won't be documented. +echo "Generating the docs based on only the functionality given in the __all__." +sphinx-build -M html source build diff --git a/datasets/doc/source/_templates/autosummary/class.rst b/datasets/doc/source/_templates/autosummary/class.rst new file mode 100644 index 000000000000..b4b35789bc6f --- /dev/null +++ b/datasets/doc/source/_templates/autosummary/class.rst @@ -0,0 +1,33 @@ +{{ name | escape | underline}} + +.. currentmodule:: {{ module }} + +.. autoclass:: {{ objname }} + :members: + :show-inheritance: + :inherited-members: + + {% block methods %} + + {% if methods %} + .. rubric:: {{ _('Methods') }} + + .. autosummary:: + {% for item in methods %} + {% if item != "__init__" %} + ~{{ name }}.{{ item }} + {% endif %} + {%- endfor %} + {% endif %} + {% endblock %} + + {% block attributes %} + {% if attributes %} + .. rubric:: {{ _('Attributes') }} + + .. autosummary:: + {% for item in attributes %} + ~{{ name }}.{{ item }} + {%- endfor %} + {% endif %} + {% endblock %} diff --git a/datasets/doc/source/_templates/autosummary/module.rst b/datasets/doc/source/_templates/autosummary/module.rst new file mode 100644 index 000000000000..571db198d27c --- /dev/null +++ b/datasets/doc/source/_templates/autosummary/module.rst @@ -0,0 +1,66 @@ +{{ name | escape | underline}} + +.. automodule:: {{ fullname }} + + {% block attributes %} + {% if attributes %} + .. rubric:: Module Attributes + + .. autosummary:: + :toctree: + {% for item in attributes %} + {{ item }} + {%- endfor %} + {% endif %} + {% endblock %} + + {% block functions %} + {% if functions %} + .. rubric:: {{ _('Functions') }} + + .. autosummary:: + :toctree: + {% for item in functions %} + {{ item }} + {%- endfor %} + {% endif %} + {% endblock %} + + {% block classes %} + {% if classes %} + .. rubric:: {{ _('Classes') }} + + .. autosummary:: + :toctree: + :template: autosummary/class.rst + {% for item in classes %} + {{ item }} + {%- endfor %} + {% endif %} + {% endblock %} + + {% block exceptions %} + {% if exceptions %} + .. rubric:: {{ _('Exceptions') }} + + .. autosummary:: + :toctree: + {% for item in exceptions %} + {{ item }} + {%- endfor %} + {% endif %} + {% endblock %} + +{% block modules %} +{% if modules %} +.. rubric:: Modules + +.. autosummary:: + :toctree: + :template: autosummary/module.rst + :recursive: +{% for item in modules %} + {{ item }} +{%- endfor %} +{% endif %} +{% endblock %} diff --git a/datasets/doc/source/conf.py b/datasets/doc/source/conf.py index 4fccaf0ef084..32baa6dd1471 100644 --- a/datasets/doc/source/conf.py +++ b/datasets/doc/source/conf.py @@ -61,8 +61,42 @@ "nbsphinx", ] +# Generate .rst files autosummary_generate = True +# Document ONLY the objects from __all__ (present in __init__ files). +# It will be done recursively starting from flwr_dataset.__init__ +# It's controlled in the index.rst file. +autosummary_ignore_module_all = False + +# Each class and function docs start with the path to it +# Make the flwr_datasets.federated_dataset.FederatedDataset appear as FederatedDataset +# The full name is still at the top of the page +add_module_names = False + +def find_test_modules(package_path): + """Go through the python files and exclude every *_test.py file.""" + full_path_modules = [] + for root, dirs, files in os.walk(package_path): + for file in files: + if file.endswith('_test.py'): + # Construct the module path relative to the package directory + full_path = os.path.join(root, file) + relative_path = os.path.relpath(full_path, package_path) + # Convert file path to dotted module path + module_path = os.path.splitext(relative_path)[0].replace(os.sep, '.') + full_path_modules.append(module_path) + modules = [] + for full_path_module in full_path_modules: + parts = full_path_module.split('.') + for i in range(len(parts)): + modules.append('.'.join(parts[i:])) + return modules + +# Stop from documenting the *_test.py files. +# That's the only way to do that in autosummary (make the modules as mock_imports). +autodoc_mock_imports = find_test_modules(os.path.abspath("../../")) + # Add any paths that contain templates here, relative to this directory. templates_path = ["_templates"] diff --git a/datasets/doc/source/index.rst b/datasets/doc/source/index.rst index 7b19624b341a..ae7e7259f504 100644 --- a/datasets/doc/source/index.rst +++ b/datasets/doc/source/index.rst @@ -38,11 +38,15 @@ References Information-oriented API reference and other reference material. -.. toctree:: - :maxdepth: 2 +.. autosummary:: + :toctree: ref-api + :template: autosummary/module.rst :caption: API reference + :recursive: + + flwr_datasets + - ref-api-flwr-datasets Main features ------------- diff --git a/datasets/doc/source/ref-api-flwr-datasets.rst b/datasets/doc/source/ref-api-flwr-datasets.rst deleted file mode 100644 index 2e6a9e731add..000000000000 --- a/datasets/doc/source/ref-api-flwr-datasets.rst +++ /dev/null @@ -1,27 +0,0 @@ -flwr\_datasets (Python API reference) -====================== - -Federated Dataset ------------------ -.. autoclass:: flwr_datasets.federated_dataset.FederatedDataset - :members: - - -partitioner ------------ - -.. automodule:: flwr_datasets.partitioner - - -Partitioner ------------ - -.. autoclass:: flwr_datasets.partitioner.Partitioner - :members: - - -IID Partitioner ---------------- - -.. autoclass:: flwr_datasets.partitioner.IidPartitioner - :members: diff --git a/dev/build-docs.sh b/dev/build-docs.sh index 0c913c6fc1d8..45a4dfca0adf 100755 --- a/dev/build-docs.sh +++ b/dev/build-docs.sh @@ -13,8 +13,7 @@ cd examples/doc make docs cd $ROOT -cd datasets/doc -make docs +./datasets/dev/build-flwr-datasets-docs.sh cd $ROOT cd doc