diff --git a/docs/about/features_index/taskrunner.rst b/docs/about/features_index/taskrunner.rst
index 9b3874d9a8..034163de02 100644
--- a/docs/about/features_index/taskrunner.rst
+++ b/docs/about/features_index/taskrunner.rst
@@ -7,20 +7,16 @@
Task Runner API
================
+Let's take a deeper dive into the Task Runner API. If you haven't already, we suggest checking out the :ref:`quick_start` for a primer on running a simple experiment on a single node.
-An overview of this workflow is shown below.
+The steps for transitioning from a local experiment to a distributed federation are best understood with the following diagram.
.. figure:: ../../images/openfl_flow.png
-.. centered:: Overview of the Aggregator-Based Workflow
+.. centered:: Overview of a Task Runner experiment distributed across multiple nodes
-There are two ways to run federation without Director:
-
-- `Bare Metal Approach`_
-- `Docker Approach`_
-
-
-This workflow uses short-lived components in a federation, which is terminated when the experiment is finished. The components are as follows:
+The Task Runner API uses short-lived components in a federation that is terminated when the experiment finishes. The components are as follows:
- The *Collaborator* uses a local dataset to train a global model and the *Aggregator* receives model updates from *Collaborators* and aggregates them to create the new global model.
- The *Aggregator* is framework-agnostic, while the *Collaborator* can use any deep learning frameworks, such as `TensorFlow `_\* \ or `PyTorch `_\*\.
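+
+In brief, once the workspace, certificates, and plan have been prepared (the detailed steps are covered later in this document), each component is started on its own node with a single :code:`fx` command. A minimal sketch, using an illustrative collaborator name:
+
+.. code-block:: console
+
+    # On the aggregator node
+    fx aggregator start
+
+    # On each collaborator node (the collaborator name below is just an example)
+    fx collaborator start -n collaborator1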
@@ -504,8 +500,8 @@ In fact, the :code:`get_model()` method returns a **TaskRunner** object loaded w
.. _running_the_federation_docker:
-Docker Approach
----------------
+Running inside Docker
+---------------------
There are two ways you can run |productName| with Docker\*\.
@@ -571,4 +567,4 @@ Option 2: Deploy Your Workspace in a Docker Container
.. toctree
.. overview.how_can_intel_protect_federated_learning
-.. overview.what_is_intel_federated_learning
\ No newline at end of file
+.. overview.what_is_intel_federated_learning
diff --git a/docs/conf.py b/docs/conf.py
index 5fd63f6114..8e0983bcbd 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -35,6 +35,7 @@
'sphinx.ext.autosectionlabel',
'sphinx.ext.napoleon',
'sphinx-prompt',
+ 'sphinx_copybutton',
'sphinx_substitution_extensions',
'sphinx.ext.ifconfig',
'sphinxcontrib.mermaid',
diff --git a/docs/get_started/examples.rst b/docs/get_started/examples.rst
index b083455a75..4b9ad39f66 100644
--- a/docs/get_started/examples.rst
+++ b/docs/get_started/examples.rst
@@ -7,29 +7,46 @@
Examples for Running a Federation
=================================
-|productName| currently offers three ways to set up and run experiments with a federation:
-the Task Runner API, the Interactive API, and the experimental workflow interface.
+|productName| currently offers four ways to set up and run experiments with a federation:
+the Task Runner API, the Python Native API, the Interactive API, and the Workflow API.
+The Task Runner API is advised for production scenarios where the workload needs to be verified prior to execution, whereas the Python Native API provides a clean Python interface on top of it, intended for simulation purposes.
The Interactive API introduces a convenient way to set up a federation and brings “long-lived” components in a federation (“Director” and “Envoy”),
-while the Task Runner API workflow is advised for scenarios where the workload needs to be verified prior to execution. In contrast, the experimental workflow interface
-is introduce to provide significant flexility to researchers and developers in the construction of federated learning experiments.
+while the Workflow API, which is currently experimental,
+is designed to provide significant flexibility to researchers and developers in the construction of federated learning experiments.
+
+As OpenFL nears its 2.0 release, we expect to consolidate these APIs and make the Workflow API the primary interface going forward. See our `roadmap `_ for more details.
-------------------------
Task Runner API
-------------------------
-Formulate the experiment as a series of tasks, or a flow.
+Formulate the experiment as a series of tasks coordinated by a Federated Learning Plan.
-See :ref:`taskrunner_pytorch_mnist`
+See :ref:`running_the_task_runner`
.. toctree::
:hidden:
:maxdepth: 1
- examples/taskrunner_pytorch_mnist
+   ../about/features_index/taskrunner
+
+-------------------------
+Python Native API
+-------------------------
+Intended for quick simulation purposes.
+
+See :ref:`python_native_pytorch_mnist`
+
+.. toctree::
+ :hidden:
+ :maxdepth: 1
+
+ examples/python_native_pytorch_mnist
+
-------------------------
Interactive API
-------------------------
-Setup long-lived components to run many experiments in series.
+Set up long-lived components to run many experiments.
See :ref:`interactive_tensorflow_mnist`
@@ -55,4 +72,4 @@ See :ref:`workflowinterface_pytorch_mnist`
.. note::
- Please visit `repository `_ for a full list of tutorials
\ No newline at end of file
+ Please visit `repository `_ for a full list of tutorials
diff --git a/docs/get_started/examples/taskrunner_pytorch_mnist.rst b/docs/get_started/examples/python_native_pytorch_mnist.rst
similarity index 94%
rename from docs/get_started/examples/taskrunner_pytorch_mnist.rst
rename to docs/get_started/examples/python_native_pytorch_mnist.rst
index eed7161ea4..8105ad495c 100644
--- a/docs/get_started/examples/taskrunner_pytorch_mnist.rst
+++ b/docs/get_started/examples/python_native_pytorch_mnist.rst
@@ -1,13 +1,13 @@
.. # Copyright (C) 2020-2023 Intel Corporation
.. # SPDX-License-Identifier: Apache-2.0
-.. _taskrunner_pytorch_mnist:
+.. _python_native_pytorch_mnist:
==========================================
-Task Runner API: Federated PyTorch MNIST
+Python Native API: Federated PyTorch MNIST
==========================================
-In this tutorial, we will set up a federation and train a basic PyTorch model on the MNIST dataset using the task runner API.
+In this tutorial, we will set up a federation and train a basic PyTorch model on the MNIST dataset using the Python Native API.
See `full notebook `_.
.. note::
@@ -170,4 +170,4 @@ If we want to pass in custom plan settings, we can easily do that with the overr
.. code-block:: python
#Save final model
- final_fl_model.save_native('final_pytorch_model')
\ No newline at end of file
+ final_fl_model.save_native('final_pytorch_model')
diff --git a/docs/get_started/quickstart.rst b/docs/get_started/quickstart.rst
index c14894f805..52e664177e 100644
--- a/docs/get_started/quickstart.rst
+++ b/docs/get_started/quickstart.rst
@@ -8,32 +8,361 @@ Quick Start
=====================
|productName| has a variety of APIs to choose from when setting up and running a federation.
-In this quick start guide, we will demonstrate how to run a simple federated learning example using the Task Runner API and Hello Federation script
+In this quick start guide, we will demonstrate how to run a simple federated learning example using the Task Runner API.
-.. note::
- The example used in this section is designed primarily to demonstrate functionality of the package and its components. It is not the recommended method for running a real world federation.
- See :ref:`openfl_examples` for details.
+.. _creating_a_federation:
-.. _hello_federation:
+********************************
+Creating a federation in 5 steps
+********************************
-*********************************
-Hello Federation
-*********************************
-.. note::
+To begin, we recommend installing OpenFL inside a Python virtual environment. This can be done with the following commands:
- Ensure you have installed the |productName| package.
+.. code-block:: console
+
+ pip install virtualenv
+ virtualenv ~/openfl-quickstart
+ source ~/openfl-quickstart/bin/activate
+ pip install openfl
- See :ref:`install_package` for details.
-We will use the `"Hello Federation" python script `_ to quickly create a federation (an aggregator node and two collaborator nodes) to test the project pipeline.
+Now you're ready to run your first federation! Copying these commands to your terminal will run a simple federation with an aggregator and two collaborators, all on your local machine. These commands can be broken down into five steps, which you can read more about `here <../about/features_index/taskrunner.html#step-1-create-a-workspace>`_.
-.. literalinclude:: ../../tests/github/test_hello_federation.py
- :language: python
+1. Setup Federation Workspace & Certificate Authority (CA) for Secure Communication
+2. Setup Aggregator & Initialize Federation Plan + Model
+3. Setup Collaborator 1
+4. Setup Collaborator 2
+5. Run the Federation
-Run the script
+.. code-block:: console
+
+ ############################################################################################
+ # Step 1: Setup Federation Workspace & Certificate Authority (CA) for Secure Communication #
+ ############################################################################################
+
+ # Generate an OpenFL Workspace. This example will train a pytorch
+ # CNN model on the MNIST dataset
+ fx workspace create --template torch_cnn_mnist --prefix my_workspace
+ cd my_workspace
+
+ # This will create a certificate authority (CA), so the participants communicate over a secure TLS Channel
+ fx workspace certify
+
+ #################################################################
+ # Step 2: Setup Aggregator & Initialize Federation Plan + Model #
+ #################################################################
+
+ # Generate a Certificate Signing Request (CSR) for the Aggregator
+ fx aggregator generate-cert-request
+
+ # The CA signs the aggregator's request, which is now available in the workspace
+ fx aggregator certify --silent
+
+ # Initialize FL Plan and Model Weights for the Federation
+ fx plan initialize
+
+ ################################
+ # Step 3: Setup Collaborator 1 #
+ ################################
+
+ # Create a collaborator named "collaborator1" that will use data path "1"
+ fx collaborator create -n collaborator1 -d 1
+
+ # Generate a CSR for collaborator1
+ fx collaborator generate-cert-request -n collaborator1
+
+ # The CA signs collaborator1's certificate
+ fx collaborator certify -n collaborator1 --silent
+
+ ################################
+ # Step 4: Setup Collaborator 2 #
+ ################################
+
+ # Create a collaborator named "collaborator2" that will use data path "2"
+ fx collaborator create -n collaborator2 -d 2
+
+ # Generate a CSR for collaborator2
+ fx collaborator generate-cert-request -n collaborator2
+
+ # The CA signs collaborator2's certificate
+ fx collaborator certify -n collaborator2 --silent
+
+ ##############################
+ # Step 5. Run the Federation #
+ ##############################
+
+ # Run the Aggregator
+ fx aggregator start &
+
+ # Run Collaborator 1
+ fx collaborator start -n collaborator1 &
+
+ # Run Collaborator 2
+ fx collaborator start -n collaborator2
+
+ echo "Congratulations! You've run your first federation with OpenFL"
+
+
+You should see output similar to the following at the end of the experiment:
.. code-block:: console
- $ python test_hello_federation.py
\ No newline at end of file
+ INFO Starting round 9... aggregator.py:897
+ [15:36:28] INFO Waiting for tasks... collaborator.py:178
+ INFO Sending tasks to collaborator collaborator2 for round 9 aggregator.py:329
+ INFO Received the following tasks: [name: "aggregated_model_validation" collaborator.py:143
+ , name: "train"
+ , name: "locally_tuned_model_validation"
+ ]
+ [15:36:30] METRIC Round 9, collaborator collaborator2 is sending metric for task aggregated_model_validation: accuracy 0.983597 collaborator.py:415
+ [15:36:31] INFO Collaborator collaborator2 is sending task results for aggregated_model_validation, round 9 aggregator.py:520
+ METRIC Round 9, collaborator validate_agg aggregated_model_validation result accuracy: 0.983597 aggregator.py:559
+ [15:36:31] INFO Run 0 epoch of 9 round runner_pt.py:148
+ [15:36:31] INFO Waiting for tasks... collaborator.py:178
+ INFO Sending tasks to collaborator collaborator1 for round 9 aggregator.py:329
+ INFO Received the following tasks: [name: "aggregated_model_validation" collaborator.py:143
+ , name: "train"
+ , name: "locally_tuned_model_validation"
+ ]
+ [15:36:33] METRIC Round 9, collaborator collaborator1 is sending metric for task aggregated_model_validation: accuracy 0.981000 collaborator.py:415
+ [15:36:34] INFO Collaborator collaborator1 is sending task results for aggregated_model_validation, round 9 aggregator.py:520
+ METRIC Round 9, collaborator validate_agg aggregated_model_validation result accuracy: 0.981000 aggregator.py:559
+ [15:36:34] INFO Run 0 epoch of 9 round runner_pt.py:148
+ [15:36:34] METRIC Round 9, collaborator collaborator2 is sending metric for task train: cross_entropy 0.059750 collaborator.py:415
+ [15:36:35] INFO Collaborator collaborator2 is sending task results for train, round 9 aggregator.py:520
+ METRIC Round 9, collaborator metric train result cross_entropy: 0.059750 aggregator.py:559
+ [15:36:35] METRIC Round 9, collaborator collaborator2 is sending metric for task locally_tuned_model_validation: accuracy 0.979596 collaborator.py:415
+ INFO Collaborator collaborator2 is sending task results for locally_tuned_model_validation, round 9 aggregator.py:520
+ METRIC Round 9, collaborator validate_local locally_tuned_model_validation result accuracy: 0.979596 aggregator.py:559
+ INFO Waiting for tasks... collaborator.py:178
+ [15:36:37] METRIC Round 9, collaborator collaborator1 is sending metric for task train: cross_entropy 0.019203 collaborator.py:415
+ [15:36:38] INFO Collaborator collaborator1 is sending task results for train, round 9 aggregator.py:520
+ METRIC Round 9, collaborator metric train result cross_entropy: 0.019203 aggregator.py:559
+ [15:36:38] METRIC Round 9, collaborator collaborator1 is sending metric for task locally_tuned_model_validation: accuracy 0.977600 collaborator.py:415
+ INFO Collaborator collaborator1 is sending task results for locally_tuned_model_validation, round 9 aggregator.py:520
+ METRIC Round 9, collaborator validate_local locally_tuned_model_validation result accuracy: 0.977600 aggregator.py:559
+ METRIC Round 9, aggregator: train cross_entropy: 0.039476
+ [15:36:39] METRIC Round 9, aggregator: aggregated_model_validation accuracy: 0.982298
+ METRIC Round 9: saved the best model with score 0.982298 aggregator.py:854
+ METRIC Round 9, aggregator: locally_tuned_model_validation aggregator.py:838
+ accuracy:
+ 0.978598
+ INFO Saving round 10 model... aggregator.py:890
+ INFO Experiment Completed. Cleaning up... aggregator.py:895
+ [15:36:39] INFO Waiting for tasks... collaborator.py:178
+ INFO Sending signal to collaborator collaborator1 to shutdown... aggregator.py:283
+ INFO End of Federation reached. Exiting... collaborator.py:150
+
+ ✔ OK
+ [15:36:46] INFO Waiting for tasks... collaborator.py:178
+ [15:36:46] INFO Sending signal to collaborator collaborator2 to shutdown... aggregator.py:283
+ INFO End of Federation reached. Exiting... collaborator.py:150
+
+ ✔ OK
+
+ Congratulations! You've run your first federation with OpenFL
+
+***************************
+Working with your own model
+***************************
+
+Now that you've run your first federation, let's see how to replace the model used in the federation. After copying in the commands above, you should be in the :code:`my_workspace` directory. Every workspace has a :code:`src` directory that contains the Task Runner, an OpenFL interface that defines the deep learning model as well as the training and validation functions that run on that model. In this case, the Task Runner is defined in :code:`src/taskrunner.py`. After opening it, you'll see the following:
+
+.. code-block:: python
+
+ class PyTorchCNN(PyTorchTaskRunner):
+ """
+ Simple CNN for classification.
+
+ PyTorchTaskRunner inherits from nn.module, so you can define your model
+ in the same way that you would for PyTorch
+ """
+
+ def __init__(self, device='cpu', **kwargs):
+ """Initialize.
+
+ Args:
+ device: The hardware device to use for training (Default = "cpu")
+ **kwargs: Additional arguments to pass to the function
+
+ """
+ super().__init__(device=device, **kwargs)
+
+ ####################################
+ # Your model goes here #
+ ####################################
+ self.conv1 = nn.Conv2d(1, 20, 2, 1)
+ self.conv2 = nn.Conv2d(20, 50, 5, 1)
+ self.fc1 = nn.Linear(800, 500)
+ self.fc2 = nn.Linear(500, 10)
+ self.to(device)
+ ####################################
+
+ ######################################################################
+ # Your optimizer goes here #
+ # #
+ # `self.optimizer` must be set for optimizer weights to be federated #
+ ######################################################################
+ self.optimizer = optim.Adam(self.parameters(), lr=1e-4)
+
+ # Set the loss function
+ self.loss_fn = F.cross_entropy
+
+
+ def forward(self, x):
+ """
+ Forward pass of the model.
+
+ Args:
+ x: Data input to the model for the forward pass
+ """
+ x = F.relu(self.conv1(x))
+ x = F.max_pool2d(x, 2, 2)
+ x = F.relu(self.conv2(x))
+ x = F.max_pool2d(x, 2, 2)
+ x = x.view(-1, 800)
+ x = F.relu(self.fc1(x))
+ x = self.fc2(x)
+ return x
+
+:code:`PyTorchTaskRunner` inherits from :code:`nn.Module`, so changing your deep learning model is as easy as modifying the network layers (i.e. :code:`self.conv1`, etc.) in the :code:`__init__` function and then defining your :code:`forward` function. You'll notice that, unlike in plain PyTorch, the optimizer is also defined in this :code:`__init__` function. This is so that both the model and optimizer weights can be distributed as part of the federation.
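+
+For example, here is a minimal, hypothetical sketch of a customized Task Runner (the class name, layer sizes, and dropout layer below are illustrative and not part of the shipped workspace; only the pattern matters):
+
+.. code-block:: python
+
+    import torch.nn as nn
+    import torch.nn.functional as F
+    import torch.optim as optim
+
+    from openfl.federated import PyTorchTaskRunner
+
+
+    class MyCNN(PyTorchTaskRunner):
+        """Hypothetical variant of the shipped CNN with an extra dropout layer."""
+
+        def __init__(self, device='cpu', **kwargs):
+            super().__init__(device=device, **kwargs)
+
+            # Model layers: same pattern as the shipped example, different sizes
+            self.conv1 = nn.Conv2d(1, 32, 3, 1)
+            self.conv2 = nn.Conv2d(32, 64, 3, 1)
+            self.dropout = nn.Dropout(0.25)
+            self.fc1 = nn.Linear(64 * 5 * 5, 128)
+            self.fc2 = nn.Linear(128, 10)
+            self.to(device)
+
+            # `self.optimizer` must be set for optimizer weights to be federated
+            self.optimizer = optim.Adam(self.parameters(), lr=1e-3)
+            self.loss_fn = F.cross_entropy
+
+        def forward(self, x):
+            x = F.max_pool2d(F.relu(self.conv1(x)), 2, 2)
+            x = F.max_pool2d(F.relu(self.conv2(x)), 2, 2)
+            x = self.dropout(x)
+            x = x.view(x.size(0), -1)   # flatten to (batch, 64 * 5 * 5) for 28x28 inputs
+            x = F.relu(self.fc1(x))
+            return self.fc2(x)
+
+If you rename the class, remember to point the plan at it as well (the :code:`task_runner.template` entry in :code:`plan/plan.yaml`, e.g. :code:`src.taskrunner.MyCNN`).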
+
+******************************************
+Defining your own train and validate tasks
+******************************************
+
+If you continue scrolling down in :code:`src/taskrunner.py`, you'll see two functions: :code:`train_` and :code:`validate_`. These are the primary tasks performed by the collaborators that have access to local data.
+
+.. code-block:: python
+
+ def train_(self, train_dataloader: Iterator[Tuple[np.ndarray, np.ndarray]]) -> Metric:
+ """
+ Train single epoch.
+
+ Override this function in order to use custom training.
+
+ Args:
+ train_dataloader: Train dataset batch generator. Yields (samples, targets) tuples of
+ size = `self.data_loader.batch_size`.
+ Returns:
+ Metric: An object containing name and np.ndarray value.
+ """
+ losses = []
+ for data, target in train_dataloader:
+ data, target = data.to(self.device), target.to(self.device)
+ self.optimizer.zero_grad()
+ output = self(data)
+ loss = self.loss_fn(output, target)
+ loss.backward()
+ self.optimizer.step()
+ losses.append(loss.detach().cpu().numpy())
+ loss = np.mean(losses)
+ return Metric(name=self.loss_fn.__name__, value=np.array(loss))
+
+
+ def validate_(self, validation_dataloader: Iterator[Tuple[np.ndarray, np.ndarray]]) -> Metric:
+ """
+ Perform validation on PyTorch Model
+
+ Override this function for your own custom validation function
+
+ Args:
+ validation_dataloader: Validation dataset batch generator. Yields (samples, targets) tuples
+ Returns:
+ Metric: An object containing name and np.ndarray value
+ """
+
+ total_samples = 0
+ val_score = 0
+ with torch.no_grad():
+ for data, target in validation_dataloader:
+ samples = target.shape[0]
+ total_samples += samples
+ data, target = data.to(self.device), target.to(self.device, dtype=torch.int64)
+ output = self(data)
+ # get the index of the max log-probability
+ pred = output.argmax(dim=1)
+ val_score += pred.eq(target).sum().cpu().numpy()
+
+ accuracy = val_score / total_samples
+ return Metric(name='accuracy', value=np.array(accuracy))
+
+Each function is passed a dataloader and returns a :code:`Metric` associated with that task. In this example, the :code:`train_` function returns the cross-entropy loss for an epoch, and the :code:`validate_` function returns the accuracy. You'll see these metrics reported when running the collaborator locally, and the aggregator will report the average metrics across all collaborators.
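+
+If you want to report something other than plain accuracy, you can override :code:`validate_` with your own computation, as long as it still returns a :code:`Metric`. Below is a minimal, hypothetical sketch that reports top-3 accuracy instead; it assumes the same imports already present in :code:`src/taskrunner.py` (:code:`numpy as np`, :code:`torch`, and :code:`Metric`) and is meant to replace the :code:`validate_` method inside the Task Runner class:
+
+.. code-block:: python
+
+    def validate_(self, validation_dataloader) -> Metric:
+        """Report top-3 accuracy instead of plain accuracy (illustrative only)."""
+        total_samples = 0
+        top3_hits = 0
+        with torch.no_grad():
+            for data, target in validation_dataloader:
+                data = data.to(self.device)
+                target = target.to(self.device, dtype=torch.int64)
+                output = self(data)
+                # indices of the 3 highest-scoring classes for each sample
+                top3 = output.topk(3, dim=1).indices
+                top3_hits += top3.eq(target.unsqueeze(1)).any(dim=1).sum().item()
+                total_samples += target.shape[0]
+        accuracy = top3_hits / total_samples
+        return Metric(name='top3_accuracy', value=np.array(accuracy))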
+
+*****************************
+Defining your own data loader
+*****************************
+
+Now let's look at the OpenFL :code:`PyTorchDataLoader` and see how, by subclassing it, we can split the MNIST dataset across collaborators for training. You'll find the following defined in :code:`src/dataloader.py`.
+
+
+.. code-block:: python
+
+ from openfl.federated import PyTorchDataLoader
+
+ class PyTorchMNISTInMemory(PyTorchDataLoader):
+ """PyTorch data loader for MNIST dataset."""
+
+ def __init__(self, data_path, batch_size, **kwargs):
+ """Instantiate the data object.
+
+ Args:
+ data_path: The file path to the data
+ batch_size: The batch size of the data loader
+ **kwargs: Additional arguments, passed to super
+ init and load_mnist_shard
+ """
+ super().__init__(batch_size, **kwargs)
+
+ num_classes, X_train, y_train, X_valid, y_valid = load_mnist_shard(
+ shard_num=int(data_path), **kwargs)
+
+ self.X_train = X_train
+ self.y_train = y_train
+ self.train_loader = self.get_train_loader()
+
+ self.X_valid = X_valid
+ self.y_valid = y_valid
+ self.val_loader = self.get_valid_loader()
+
+ self.num_classes = num_classes
+
+This example uses the classic MNIST dataset for digit recognition. For in-memory datasets, :code:`data_path` is a number that determines which shard of the dataset the collaborator receives. By initializing the :code:`train_loader` (:code:`self.train_loader = self.get_train_loader()`) and the :code:`val_loader` (:code:`self.val_loader = self.get_valid_loader()`), these dataloaders can then be passed into the :code:`train_` and :code:`validate_` functions defined above.
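+
+If you want to convince yourself how the sharding works, you can instantiate the loader directly outside of a federation. This is a purely illustrative, hypothetical check (it assumes you run it from the workspace root so that :code:`src` is importable, and it will download MNIST into :code:`data/` on first use):
+
+.. code-block:: python
+
+    from src.dataloader import PyTorchMNISTInMemory
+
+    # Shard "1" of 2 collaborators, mirroring `fx collaborator create -n collaborator1 -d 1`
+    loader = PyTorchMNISTInMemory(data_path='1', batch_size=32, collaborator_count=2)
+
+    print(loader.num_classes)             # 10 classes for MNIST
+    images, labels = next(iter(loader.get_train_loader()))
+    print(images.shape, labels.shape)     # one training batch from this collaborator's shard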
+
+***************************************
+Changing the number of federated rounds
+***************************************
+
+Now that we've seen how to change the code, let's explore the Federated Learning Plan (FL Plan). The plan, defined in :code:`plan/plan.yaml`, configures everything about the federation that can't be expressed purely in Python. This includes information like network connectivity details, how the different components are configured, and how many rounds the federation should train. Different experiments may need more rounds to converge, depending on how similar the data is across collaborators, the model, and the number of participating collaborators. To tweak this parameter for your experiment, open :code:`plan/plan.yaml` and modify the following section:
+
+.. code-block:: yaml
+
+ aggregator:
+ settings:
+ best_state_path: save/torch_cnn_mnist_best.pbuf
+ db_store_rounds: 2
+ init_state_path: save/torch_cnn_mnist_init.pbuf
+ last_state_path: save/torch_cnn_mnist_last.pbuf
+ log_metric_callback:
+ template: src.utils.write_metric
+ rounds_to_train: 10 # Change this value to train for a different number of rounds
+ write_logs: true
+
+*****************************************************
+Starting a new federation after making custom changes
+*****************************************************
+
+Now that you've changed a few things, you can rerun the federation. Copying the commands below will reinitialize your plan with new model weights and relaunch the aggregator and two collaborators:
+
+.. code-block:: console
+
+ fx plan initialize
+ fx aggregator start &
+ fx collaborator start -n collaborator1 &
+ fx collaborator start -n collaborator2
+
+Well done! Now that you know the basics of using the Task Runner API to run OpenFL on a single node, check out some of the other :ref:`openfl_examples` for use in research and production.
diff --git a/docs/requirements-docs.txt b/docs/requirements-docs.txt
index f9a47f5bb4..162e5499d6 100644
--- a/docs/requirements-docs.txt
+++ b/docs/requirements-docs.txt
@@ -3,7 +3,8 @@
sphinx-rtd-theme
sphinx-prompt
sphinx_substitution_extensions
+sphinx-copybutton
sphinxcontrib-mermaid
pygments>=2.7.4 # not directly required, pinned by Snyk to avoid a vulnerability
sphinx>=3.0.4 # not directly required, pinned by Snyk to avoid a vulnerability
-recommonmark
\ No newline at end of file
+recommonmark
diff --git a/openfl-workspace/torch_cnn_histology/src/pt_cnn.py b/openfl-workspace/torch_cnn_histology/src/pt_cnn.py
index 73127afc22..cfa7122243 100644
--- a/openfl-workspace/torch_cnn_histology/src/pt_cnn.py
+++ b/openfl-workspace/torch_cnn_histology/src/pt_cnn.py
@@ -35,17 +35,14 @@ def __init__(self, **kwargs):
self.num_classes = self.data_loader.num_classes
self.init_network(device=self.device, **kwargs)
- self._init_optimizer(lr=kwargs.get('lr'))
+ self._init_optimizer(lr=kwargs.get("lr"))
self.initialize_tensorkeys_for_functions()
def _init_optimizer(self, lr):
"""Initialize the optimizer."""
self.optimizer = optim.Adam(self.parameters(), lr=float(lr or 1e-3))
- def init_network(self,
- device,
- print_model=True,
- **kwargs):
+ def init_network(self, device, print_model=True, **kwargs):
"""Create the network (model).
Args:
@@ -54,9 +51,8 @@ def init_network(self,
**kwargs: Additional arguments to pass to the function
"""
- channel = self.data_loader.get_feature_shape()[
- 0] # (channel, dim1, dim2)
- conv_kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}
+ channel = self.data_loader.get_feature_shape()[0] # (channel, dim1, dim2)
+ conv_kwargs = {"kernel_size": 3, "stride": 1, "padding": 1}
self.conv1 = nn.Conv2d(channel, 16, **conv_kwargs)
self.conv2 = nn.Conv2d(16, 32, **conv_kwargs)
self.conv3 = nn.Conv2d(32, 64, **conv_kwargs)
@@ -101,8 +97,9 @@ def forward(self, x):
x = self.fc2(x)
return x
- def validate(self, col_name, round_num, input_tensor_dict,
- use_tqdm=False, **kwargs):
+ def validate_task(
+ self, col_name, round_num, input_tensor_dict, use_tqdm=False, **kwargs
+ ):
"""Validate.
Run validation of the model on the local data.
@@ -126,31 +123,34 @@ def validate(self, col_name, round_num, input_tensor_dict,
loader = self.data_loader.get_valid_loader()
if use_tqdm:
- loader = tqdm.tqdm(loader, desc='validate')
+ loader = tqdm.tqdm(loader, desc="validate")
with torch.no_grad():
for data, target in loader:
samples = target.shape[0]
total_samples += samples
- data, target = (torch.tensor(data).to(self.device),
- torch.tensor(target).to(self.device))
+ data, target = (
+ torch.tensor(data).to(self.device),
+ torch.tensor(target).to(self.device),
+ )
output = self(data)
# get the index of the max log-probability
pred = output.argmax(dim=1)
val_score += pred.eq(target).sum().cpu().numpy()
origin = col_name
- suffix = 'validate'
- if kwargs['apply'] == 'local':
- suffix += '_local'
+ suffix = "validate"
+ if kwargs["apply"] == "local":
+ suffix += "_local"
else:
- suffix += '_agg'
- tags = ('metric', suffix)
+ suffix += "_agg"
+ tags = ("metric", suffix)
# TODO figure out a better way to pass in metric for
# this pytorch validate function
output_tensor_dict = {
- TensorKey('acc', origin, round_num, True, tags):
- np.array(val_score / total_samples)
+ TensorKey("acc", origin, round_num, True, tags): np.array(
+ val_score / total_samples
+ )
}
# empty list represents metrics that should only be stored locally
@@ -162,4 +162,4 @@ def reset_opt_vars(self):
Resets the optimizer state variables.
"""
- self._init_optimizer(lr=self.optimizer.defaults.get('lr'))
+ self._init_optimizer(lr=self.optimizer.defaults.get("lr"))
diff --git a/openfl-workspace/torch_cnn_histology_gramine_ready/src/pt_cnn.py b/openfl-workspace/torch_cnn_histology_gramine_ready/src/pt_cnn.py
index 73127afc22..cfa7122243 100644
--- a/openfl-workspace/torch_cnn_histology_gramine_ready/src/pt_cnn.py
+++ b/openfl-workspace/torch_cnn_histology_gramine_ready/src/pt_cnn.py
@@ -35,17 +35,14 @@ def __init__(self, **kwargs):
self.num_classes = self.data_loader.num_classes
self.init_network(device=self.device, **kwargs)
- self._init_optimizer(lr=kwargs.get('lr'))
+ self._init_optimizer(lr=kwargs.get("lr"))
self.initialize_tensorkeys_for_functions()
def _init_optimizer(self, lr):
"""Initialize the optimizer."""
self.optimizer = optim.Adam(self.parameters(), lr=float(lr or 1e-3))
- def init_network(self,
- device,
- print_model=True,
- **kwargs):
+ def init_network(self, device, print_model=True, **kwargs):
"""Create the network (model).
Args:
@@ -54,9 +51,8 @@ def init_network(self,
**kwargs: Additional arguments to pass to the function
"""
- channel = self.data_loader.get_feature_shape()[
- 0] # (channel, dim1, dim2)
- conv_kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}
+ channel = self.data_loader.get_feature_shape()[0] # (channel, dim1, dim2)
+ conv_kwargs = {"kernel_size": 3, "stride": 1, "padding": 1}
self.conv1 = nn.Conv2d(channel, 16, **conv_kwargs)
self.conv2 = nn.Conv2d(16, 32, **conv_kwargs)
self.conv3 = nn.Conv2d(32, 64, **conv_kwargs)
@@ -101,8 +97,9 @@ def forward(self, x):
x = self.fc2(x)
return x
- def validate(self, col_name, round_num, input_tensor_dict,
- use_tqdm=False, **kwargs):
+ def validate_task(
+ self, col_name, round_num, input_tensor_dict, use_tqdm=False, **kwargs
+ ):
"""Validate.
Run validation of the model on the local data.
@@ -126,31 +123,34 @@ def validate(self, col_name, round_num, input_tensor_dict,
loader = self.data_loader.get_valid_loader()
if use_tqdm:
- loader = tqdm.tqdm(loader, desc='validate')
+ loader = tqdm.tqdm(loader, desc="validate")
with torch.no_grad():
for data, target in loader:
samples = target.shape[0]
total_samples += samples
- data, target = (torch.tensor(data).to(self.device),
- torch.tensor(target).to(self.device))
+ data, target = (
+ torch.tensor(data).to(self.device),
+ torch.tensor(target).to(self.device),
+ )
output = self(data)
# get the index of the max log-probability
pred = output.argmax(dim=1)
val_score += pred.eq(target).sum().cpu().numpy()
origin = col_name
- suffix = 'validate'
- if kwargs['apply'] == 'local':
- suffix += '_local'
+ suffix = "validate"
+ if kwargs["apply"] == "local":
+ suffix += "_local"
else:
- suffix += '_agg'
- tags = ('metric', suffix)
+ suffix += "_agg"
+ tags = ("metric", suffix)
# TODO figure out a better way to pass in metric for
# this pytorch validate function
output_tensor_dict = {
- TensorKey('acc', origin, round_num, True, tags):
- np.array(val_score / total_samples)
+ TensorKey("acc", origin, round_num, True, tags): np.array(
+ val_score / total_samples
+ )
}
# empty list represents metrics that should only be stored locally
@@ -162,4 +162,4 @@ def reset_opt_vars(self):
Resets the optimizer state variables.
"""
- self._init_optimizer(lr=self.optimizer.defaults.get('lr'))
+ self._init_optimizer(lr=self.optimizer.defaults.get("lr"))
diff --git a/openfl-workspace/torch_cnn_mnist/plan/plan.yaml b/openfl-workspace/torch_cnn_mnist/plan/plan.yaml
index 47ba88b3f2..d85442e8e4 100644
--- a/openfl-workspace/torch_cnn_mnist/plan/plan.yaml
+++ b/openfl-workspace/torch_cnn_mnist/plan/plan.yaml
@@ -10,7 +10,7 @@ aggregator :
last_state_path : save/torch_cnn_mnist_last.pbuf
rounds_to_train : 10
log_metric_callback :
- template : src.mnist_utils.write_metric
+ template : src.utils.write_metric
collaborator :
@@ -22,7 +22,7 @@ collaborator :
data_loader :
defaults : plan/defaults/data_loader.yaml
- template : src.ptmnist_inmemory.PyTorchMNISTInMemory
+ template : src.dataloader.PyTorchMNISTInMemory
settings :
collaborator_count : 2
data_group_name : mnist
@@ -30,7 +30,7 @@ data_loader :
task_runner :
defaults : plan/defaults/task_runner.yaml
- template : src.pt_cnn.PyTorchCNN
+ template : src.taskrunner.PyTorchCNN
network :
defaults : plan/defaults/network.yaml
diff --git a/openfl-workspace/torch_cnn_mnist/src/mnist_utils.py b/openfl-workspace/torch_cnn_mnist/src/dataloader.py
similarity index 67%
rename from openfl-workspace/torch_cnn_mnist/src/mnist_utils.py
rename to openfl-workspace/torch_cnn_mnist/src/dataloader.py
index 16ee801b4d..4130ceafdb 100644
--- a/openfl-workspace/torch_cnn_mnist/src/mnist_utils.py
+++ b/openfl-workspace/torch_cnn_mnist/src/dataloader.py
@@ -1,31 +1,87 @@
-# Copyright (C) 2020-2021 Intel Corporation
+# Copyright (C) 2020-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""You may copy this file as the starting point of your own model."""
-from logging import getLogger
-
-import numpy as np
-from torch.utils.tensorboard import SummaryWriter
+from openfl.federated import PyTorchDataLoader
from torchvision import datasets
from torchvision import transforms
+import numpy as np
+from logging import getLogger
logger = getLogger(__name__)
-writer = None
+class PyTorchMNISTInMemory(PyTorchDataLoader):
+ """PyTorch data loader for MNIST dataset."""
-def get_writer():
- """Create global writer object."""
- global writer
- if not writer:
- writer = SummaryWriter('./logs/cnn_mnist', flush_secs=5)
+ def __init__(self, data_path, batch_size, **kwargs):
+ """Instantiate the data object.
+ Args:
+ data_path: The file path to the data
+ batch_size: The batch size of the data loader
+ **kwargs: Additional arguments, passed to super
+ init and load_mnist_shard
+ """
+ super().__init__(batch_size, **kwargs)
-def write_metric(node_name, task_name, metric_name, metric, round_number):
- """Write metric callback."""
- get_writer()
- writer.add_scalar(f'{node_name}/{task_name}/{metric_name}', metric, round_number)
+ num_classes, X_train, y_train, X_valid, y_valid = load_mnist_shard(
+ shard_num=int(data_path), **kwargs
+ )
+
+ self.X_train = X_train
+ self.y_train = y_train
+ self.train_loader = self.get_train_loader()
+
+ self.X_valid = X_valid
+ self.y_valid = y_valid
+ self.val_loader = self.get_valid_loader()
+
+ self.num_classes = num_classes
+
+
+def load_mnist_shard(
+ shard_num, collaborator_count, categorical=False, channels_last=True, **kwargs
+):
+ """
+ Load the MNIST dataset.
+
+ Args:
+ shard_num (int): The shard to use from the dataset
+ collaborator_count (int): The number of collaborators in the
+ federation
+ categorical (bool): True = convert the labels to one-hot encoded
+ vectors (Default = True)
+ channels_last (bool): True = The input images have the channels
+ last (Default = True)
+ **kwargs: Additional parameters to pass to the function
+
+ Returns:
+ list: The input shape
+ int: The number of classes
+ numpy.ndarray: The training data
+ numpy.ndarray: The training labels
+ numpy.ndarray: The validation data
+ numpy.ndarray: The validation labels
+ """
+ num_classes = 10
+
+ (X_train, y_train), (X_valid, y_valid) = _load_raw_datashards(
+ shard_num, collaborator_count, transform=transforms.ToTensor()
+ )
+
+ logger.info(f"MNIST > X_train Shape : {X_train.shape}")
+ logger.info(f"MNIST > y_train Shape : {y_train.shape}")
+ logger.info(f"MNIST > Train Samples : {X_train.shape[0]}")
+ logger.info(f"MNIST > Valid Samples : {X_valid.shape[0]}")
+
+ if categorical:
+ # convert class vectors to binary class matrices
+ y_train = one_hot(y_train, num_classes)
+ y_valid = one_hot(y_valid, num_classes)
+
+ return num_classes, X_train, y_train, X_valid, y_valid
def one_hot(labels, classes):
@@ -57,7 +113,7 @@ def _load_raw_datashards(shard_num, collaborator_count, transform=None):
2 tuples: (image, label) of the training, validation dataset
"""
train_data, val_data = (
- datasets.MNIST('data', train=train, download=True, transform=transform)
+ datasets.MNIST("data", train=train, download=True, transform=transform)
for train in (True, False)
)
X_train_tot, y_train_tot = train_data.train_data, train_data.train_labels
@@ -72,44 +128,3 @@ def _load_raw_datashards(shard_num, collaborator_count, transform=None):
y_valid = y_valid_tot[shard_num::collaborator_count]
return (X_train, y_train), (X_valid, y_valid)
-
-
-def load_mnist_shard(shard_num, collaborator_count,
- categorical=False, channels_last=True, **kwargs):
- """
- Load the MNIST dataset.
-
- Args:
- shard_num (int): The shard to use from the dataset
- collaborator_count (int): The number of collaborators in the
- federation
- categorical (bool): True = convert the labels to one-hot encoded
- vectors (Default = True)
- channels_last (bool): True = The input images have the channels
- last (Default = True)
- **kwargs: Additional parameters to pass to the function
-
- Returns:
- list: The input shape
- int: The number of classes
- numpy.ndarray: The training data
- numpy.ndarray: The training labels
- numpy.ndarray: The validation data
- numpy.ndarray: The validation labels
- """
- num_classes = 10
-
- (X_train, y_train), (X_valid, y_valid) = _load_raw_datashards(
- shard_num, collaborator_count, transform=transforms.ToTensor())
-
- logger.info(f'MNIST > X_train Shape : {X_train.shape}')
- logger.info(f'MNIST > y_train Shape : {y_train.shape}')
- logger.info(f'MNIST > Train Samples : {X_train.shape[0]}')
- logger.info(f'MNIST > Valid Samples : {X_valid.shape[0]}')
-
- if categorical:
- # convert class vectors to binary class matrices
- y_train = one_hot(y_train, num_classes)
- y_valid = one_hot(y_valid, num_classes)
-
- return num_classes, X_train, y_train, X_valid, y_valid
diff --git a/openfl-workspace/torch_cnn_mnist/src/pt_cnn.py b/openfl-workspace/torch_cnn_mnist/src/pt_cnn.py
deleted file mode 100644
index 9dc2d60afb..0000000000
--- a/openfl-workspace/torch_cnn_mnist/src/pt_cnn.py
+++ /dev/null
@@ -1,208 +0,0 @@
-# Copyright (C) 2020-2021 Intel Corporation
-# SPDX-License-Identifier: Apache-2.0
-
-"""You may copy this file as the starting point of your own model."""
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.optim as optim
-import tqdm
-
-from openfl.federated import PyTorchTaskRunner
-from openfl.utilities import TensorKey
-
-
-def cross_entropy(output, target):
- """Binary cross-entropy metric.
-
- Args:
- output: The mode prediction
- target: The target (ground truth label)
-
- Returns:
- Binary cross-entropy with logits
-
- """
- return F.cross_entropy(input=output, target=target)
-
-
-class PyTorchCNN(PyTorchTaskRunner):
- """Simple CNN for classification."""
-
- def __init__(self, device='cpu', **kwargs):
- """Initialize.
-
- Args:
- data: The data loader class
- device: The hardware device to use for training (Default = "cpu")
- **kwargs: Additional arguments to pass to the function
-
- """
- super().__init__(device=device, **kwargs)
-
- self.num_classes = self.data_loader.num_classes
- self.init_network(device=self.device, **kwargs)
- self._init_optimizer()
- self.loss_fn = cross_entropy
- self.initialize_tensorkeys_for_functions()
-
- def _init_optimizer(self):
- """Initialize the optimizer."""
- self.optimizer = optim.Adam(self.parameters(), lr=1e-4)
-
- def init_network(self,
- device,
- print_model=True,
- pool_sqrkernel_size=2,
- conv_sqrkernel_size=5,
- conv1_channels_out=20,
- conv2_channels_out=50,
- fc2_insize=500,
- **kwargs):
- """Create the network (model).
-
- Args:
- device: The hardware device to use for training
- print_model (bool): Print the model topology (Default=True)
- pool_sqrkernel_size (int): Max pooling kernel size (Default=2),
- assumes square 2x2
- conv_sqrkernel_size (int): Convolutional filter size (Default=5),
- assumes square 5x5
- conv1_channels_out (int): Number of filters in first
- convolutional layer (Default=20)
- conv2_channels_out: Number of filters in second convolutional
- layer (Default=50)
- fc2_insize (int): Number of neurons in the
- fully-connected layer (Default = 500)
- **kwargs: Additional arguments to pass to the function
-
- FIXME: We are tracking only side lengths (rather than
- length and width) as we are assuming square
- shapes for feature and kernels.
- In order that all of the input and activation components are
- used (not cut off), we rely on a criterion: appropriate integers
- are divisible so that all casting to int perfomed below does no
- rounding (i.e. all int casting simply converts a float with '0'
- in the decimal part to an int.)
-
- (Note this criterion held for the original input sizes considered
- for this model: 28x28 and 32x32 when used with the default values
- above)
- """
- self.pool_sqrkernel_size = pool_sqrkernel_size
- channel = self.data_loader.get_feature_shape()[0] # (channel, dim1, dim2)
- self.conv1 = nn.Conv2d(channel, conv1_channels_out, conv_sqrkernel_size, 1)
-
- # perform some calculations to track the size of the single channel activations
- # channels are first for pytorch
- conv1_sqrsize_in = self.feature_shape[-1]
- conv1_sqrsize_out = conv1_sqrsize_in - (conv_sqrkernel_size - 1)
- # a pool operation happens after conv1 out
- # (note dependence on 'forward' function below)
- conv2_sqrsize_in = int(conv1_sqrsize_out / pool_sqrkernel_size)
-
- self.conv2 = nn.Conv2d(conv1_channels_out, conv2_channels_out, conv_sqrkernel_size, 1)
-
- # more tracking of single channel activation size
- conv2_sqrsize_out = conv2_sqrsize_in - (conv_sqrkernel_size - 1)
- # a pool operation happens after conv2 out
- # (note dependence on 'forward' function below)
- l0 = int(conv2_sqrsize_out / pool_sqrkernel_size)
- self.fc1_insize = l0 * l0 * conv2_channels_out
- self.fc1 = nn.Linear(self.fc1_insize, fc2_insize)
- self.fc2 = nn.Linear(fc2_insize, self.num_classes)
- if print_model:
- print(self)
- self.to(device)
-
- def forward(self, x):
- """Forward pass of the model.
-
- Args:
- x: Data input to the model for the forward pass
- """
- x = F.relu(self.conv1(x))
- pl = self.pool_sqrkernel_size
- x = F.max_pool2d(x, pl, pl)
- x = F.relu(self.conv2(x))
- x = F.max_pool2d(x, pl, pl)
- x = x.view(-1, self.fc1_insize)
- x = F.relu(self.fc1(x))
- x = self.fc2(x)
- return x
-
- def validate(self, col_name, round_num, input_tensor_dict, use_tqdm=False, **kwargs):
- """Validate.
-
- Run validation of the model on the local data.
-
- Args:
- col_name: Name of the collaborator
- round_num: What round is it
- input_tensor_dict: Required input tensors (for model)
- use_tqdm (bool): Use tqdm to print a progress bar (Default=True)
-
- Returns:
- global_output_dict: Tensors to send back to the aggregator
- local_output_dict: Tensors to maintain in the local TensorDB
-
- """
- self.rebuild_model(round_num, input_tensor_dict, validation=True)
- self.eval()
- val_score = 0
- total_samples = 0
-
- loader = self.data_loader.get_valid_loader()
- if use_tqdm:
- loader = tqdm.tqdm(loader, desc='validate')
-
- with torch.no_grad():
- for data, target in loader:
- samples = target.shape[0]
- total_samples += samples
- data, target = torch.tensor(data).to(
- self.device), torch.tensor(target).to(
- self.device, dtype=torch.int64)
- output = self(data)
- # get the index of the max log-probability
- pred = output.argmax(dim=1)
- val_score += pred.eq(target).sum().cpu().numpy()
-
- origin = col_name
- suffix = 'validate'
- if kwargs['apply'] == 'local':
- suffix += '_local'
- else:
- suffix += '_agg'
- tags = ('metric', suffix)
- # TODO figure out a better way to pass
- # in metric for this pytorch validate function
- output_tensor_dict = {
- TensorKey('acc', origin, round_num, True, tags):
- np.array(val_score / total_samples)
- }
-
- # Empty list represents metrics that should only be stored locally
- return output_tensor_dict, {}
-
- def reset_opt_vars(self):
- """Reset optimizer variables.
-
- Resets the optimizer state variables.
-
- """
- self._init_optimizer()
-
- def save_native(self, filepath):
- """
- Save model in a picked file specified by the filepath.
- Uses torch.save().
-
- Args:
- filepath (string) : Path to pickle file to be
- created by pt.save().
- Returns:
- None
- """
- torch.save(self, filepath)
diff --git a/openfl-workspace/torch_cnn_mnist/src/ptmnist_inmemory.py b/openfl-workspace/torch_cnn_mnist/src/ptmnist_inmemory.py
deleted file mode 100644
index 6ff9d6c5d1..0000000000
--- a/openfl-workspace/torch_cnn_mnist/src/ptmnist_inmemory.py
+++ /dev/null
@@ -1,41 +0,0 @@
-# Copyright (C) 2020-2021 Intel Corporation
-# SPDX-License-Identifier: Apache-2.0
-
-"""You may copy this file as the starting point of your own model."""
-
-from openfl.federated import PyTorchDataLoader
-from .mnist_utils import load_mnist_shard
-
-
-class PyTorchMNISTInMemory(PyTorchDataLoader):
- """PyTorch data loader for MNIST dataset."""
-
- def __init__(self, data_path, batch_size, **kwargs):
- """Instantiate the data object.
-
- Args:
- data_path: The file path to the data
- batch_size: The batch size of the data loader
- **kwargs: Additional arguments, passed to super
- init and load_mnist_shard
- """
- super().__init__(batch_size, **kwargs)
-
- # TODO: We should be downloading the dataset shard into a directory
- # TODO: There needs to be a method to ask how many collaborators and
- # what index/rank is this collaborator.
- # Then we have a way to automatically shard based on rank and size
- # of collaborator list.
-
- num_classes, X_train, y_train, X_valid, y_valid = load_mnist_shard(
- shard_num=int(data_path), **kwargs)
-
- self.X_train = X_train
- self.y_train = y_train
- self.train_loader = self.get_train_loader()
-
- self.X_valid = X_valid
- self.y_valid = y_valid
- self.val_loader = self.get_valid_loader()
-
- self.num_classes = num_classes
diff --git a/openfl-workspace/torch_cnn_mnist/src/taskrunner.py b/openfl-workspace/torch_cnn_mnist/src/taskrunner.py
new file mode 100644
index 0000000000..b68ce9b2a1
--- /dev/null
+++ b/openfl-workspace/torch_cnn_mnist/src/taskrunner.py
@@ -0,0 +1,120 @@
+# Copyright (C) 2020-2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+"""You may copy this file as the starting point of your own model."""
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from typing import Iterator
+from typing import Tuple
+
+from openfl.federated import PyTorchTaskRunner
+from openfl.utilities import Metric
+
+
+class PyTorchCNN(PyTorchTaskRunner):
+ """
+ Simple CNN for classification.
+
+ PyTorchTaskRunner inherits from nn.module, so you can define your model
+ in the same way that you would for PyTorch
+ """
+
+ def __init__(self, device="cpu", **kwargs):
+ """Initialize.
+
+ Args:
+ device: The hardware device to use for training (Default = "cpu")
+ **kwargs: Additional arguments to pass to the function
+
+ """
+ super().__init__(device=device, **kwargs)
+
+ # Define the model
+ self.conv1 = nn.Conv2d(1, 20, 2, 1)
+ self.conv2 = nn.Conv2d(20, 50, 5, 1)
+ self.fc1 = nn.Linear(800, 500)
+ self.fc2 = nn.Linear(500, 10)
+ self.to(device)
+
+ # `self.optimizer` must be set for optimizer weights to be federated
+ self.optimizer = optim.Adam(self.parameters(), lr=1e-4)
+
+ # Set the loss function
+ self.loss_fn = F.cross_entropy
+
+ def forward(self, x):
+ """
+ Forward pass of the model.
+
+ Args:
+ x: Data input to the model for the forward pass
+ """
+ x = F.relu(self.conv1(x))
+ x = F.max_pool2d(x, 2, 2)
+ x = F.relu(self.conv2(x))
+ x = F.max_pool2d(x, 2, 2)
+ x = x.view(-1, 800)
+ x = F.relu(self.fc1(x))
+ x = self.fc2(x)
+ return x
+
+ def train_(
+ self, train_dataloader: Iterator[Tuple[np.ndarray, np.ndarray]]
+ ) -> Metric:
+ """
+ Train single epoch.
+
+ Override this function in order to use custom training.
+
+ Args:
+ train_dataloader: Train dataset batch generator. Yields (samples, targets) tuples of
+ size = `self.data_loader.batch_size`.
+ Returns:
+ Metric: An object containing name and np.ndarray value.
+ """
+ losses = []
+ for data, target in train_dataloader:
+ data, target = data.to(self.device), target.to(self.device)
+ self.optimizer.zero_grad()
+ output = self(data)
+ loss = self.loss_fn(output, target)
+ loss.backward()
+ self.optimizer.step()
+ losses.append(loss.detach().cpu().numpy())
+ loss = np.mean(losses)
+ return Metric(name=self.loss_fn.__name__, value=np.array(loss))
+
+ def validate_(
+ self, validation_dataloader: Iterator[Tuple[np.ndarray, np.ndarray]]
+ ) -> Metric:
+ """
+ Perform validation on PyTorch Model
+
+ Override this function for your own custom validation function
+
+ Args:
+ validation_dataloader: Validation dataset batch generator.
+ Yields (samples, targets) tuples
+ Returns:
+ Metric: An object containing name and np.ndarray value
+ """
+
+ total_samples = 0
+ val_score = 0
+ with torch.no_grad():
+ for data, target in validation_dataloader:
+ samples = target.shape[0]
+ total_samples += samples
+ data, target = data.to(self.device), target.to(
+ self.device, dtype=torch.int64
+ )
+ output = self(data)
+ # get the index of the max log-probability
+ pred = output.argmax(dim=1)
+ val_score += pred.eq(target).sum().cpu().numpy()
+
+ accuracy = val_score / total_samples
+ return Metric(name="accuracy", value=np.array(accuracy))
diff --git a/openfl-workspace/torch_cnn_mnist/src/utils.py b/openfl-workspace/torch_cnn_mnist/src/utils.py
new file mode 100644
index 0000000000..6b6124e76a
--- /dev/null
+++ b/openfl-workspace/torch_cnn_mnist/src/utils.py
@@ -0,0 +1,21 @@
+# Copyright (C) 2020-2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+"""You may copy this file as the starting point of your own utilities."""
+
+from torch.utils.tensorboard import SummaryWriter
+
+writer = None
+
+
+def get_writer():
+ """Create global writer object."""
+ global writer
+ if not writer:
+ writer = SummaryWriter('./logs/cnn_mnist', flush_secs=5)
+
+
+def write_metric(node_name, task_name, metric_name, metric, round_number):
+ """Write metric callback."""
+ get_writer()
+ writer.add_scalar(f'{node_name}/{task_name}/{metric_name}', metric, round_number)
diff --git a/openfl-workspace/torch_llm_horovod/src/pt_model.py b/openfl-workspace/torch_llm_horovod/src/pt_model.py
index c4abbdff5f..5afad1bffc 100644
--- a/openfl-workspace/torch_llm_horovod/src/pt_model.py
+++ b/openfl-workspace/torch_llm_horovod/src/pt_model.py
@@ -117,7 +117,7 @@ def launch_horovod(
)
return result
- def validate(
+ def validate_task(
self, col_name, round_num, input_tensor_dict, use_tqdm=False, **kwargs
):
"""Validate.
@@ -180,7 +180,7 @@ def validate(
# Empty list represents metrics that should only be stored locally
return output_tensor_dict, {}
- def train_batches(
+ def train_task(
self, col_name, round_num, input_tensor_dict, use_tqdm=False, epochs=1, **kwargs
):
"""Train batches.
diff --git a/openfl-workspace/torch_unet_kvasir/src/fed_unet_runner.py b/openfl-workspace/torch_unet_kvasir/src/fed_unet_runner.py
index 97ac218048..362f1f53db 100644
--- a/openfl-workspace/torch_unet_kvasir/src/fed_unet_runner.py
+++ b/openfl-workspace/torch_unet_kvasir/src/fed_unet_runner.py
@@ -89,7 +89,7 @@ def forward(self, x):
x = torch.sigmoid(x)
return x
- def validate(self, col_name, round_num, input_tensor_dict, use_tqdm=True, **kwargs):
+ def validate_task(self, col_name, round_num, input_tensor_dict, use_tqdm=True, **kwargs):
"""Run validation of the model on the local data.
Args:
diff --git a/openfl-workspace/torch_unet_kvasir_gramine_ready/src/fed_unet_runner.py b/openfl-workspace/torch_unet_kvasir_gramine_ready/src/fed_unet_runner.py
index 97ac218048..362f1f53db 100644
--- a/openfl-workspace/torch_unet_kvasir_gramine_ready/src/fed_unet_runner.py
+++ b/openfl-workspace/torch_unet_kvasir_gramine_ready/src/fed_unet_runner.py
@@ -89,7 +89,7 @@ def forward(self, x):
x = torch.sigmoid(x)
return x
- def validate(self, col_name, round_num, input_tensor_dict, use_tqdm=True, **kwargs):
+ def validate_task(self, col_name, round_num, input_tensor_dict, use_tqdm=True, **kwargs):
"""Run validation of the model on the local data.
Args:
diff --git a/openfl-workspace/workspace/plan/defaults/tasks_torch.yaml b/openfl-workspace/workspace/plan/defaults/tasks_torch.yaml
index f41b0c3600..44486b23e0 100644
--- a/openfl-workspace/workspace/plan/defaults/tasks_torch.yaml
+++ b/openfl-workspace/workspace/plan/defaults/tasks_torch.yaml
@@ -1,19 +1,19 @@
aggregated_model_validation:
- function : validate
+ function : validate_task
kwargs :
apply : global
metrics :
- acc
locally_tuned_model_validation:
- function : validate
+ function : validate_task
kwargs :
apply: local
metrics :
- acc
train:
- function : train_batches
+ function : train_task
kwargs :
metrics :
- loss
diff --git a/openfl/component/collaborator/collaborator.py b/openfl/component/collaborator/collaborator.py
index 5919398388..9fb3d00660 100644
--- a/openfl/component/collaborator/collaborator.py
+++ b/openfl/component/collaborator/collaborator.py
@@ -234,7 +234,7 @@ def do_task(self, task, round_number):
# New interactive python API
# New `Core` TaskRunner contains registry of tasks
func = self.task_runner.TASK_REGISTRY[func_name]
- self.logger.info('Using Interactive Python API')
+ self.logger.debug('Using Interactive Python API')
# So far 'kwargs' contained parameters read from the plan
# those are parameters that the eperiment owner registered for
@@ -250,7 +250,7 @@ def do_task(self, task, round_number):
# TaskRunner subclassing API
# Tasks are defined as methods of TaskRunner
func = getattr(self.task_runner, func_name)
- self.logger.info('Using TaskRunner subclassing API')
+ self.logger.debug('Using TaskRunner subclassing API')
global_output_tensor_dict, local_output_tensor_dict = func(
col_name=self.collaborator_name,
diff --git a/openfl/federated/plan/plan.py b/openfl/federated/plan/plan.py
index 9d1cb699c1..be94abedfc 100644
--- a/openfl/federated/plan/plan.py
+++ b/openfl/federated/plan/plan.py
@@ -425,6 +425,9 @@ def get_task_runner(self, data_loader):
if self.runner_ is None:
self.runner_ = Plan.build(**defaults)
+ # Define task dependencies after taskrunner has been initialized
+ self.runner_.initialize_tensorkeys_for_functions()
+
return self.runner_
# Python interactive api
diff --git a/openfl/federated/task/runner_pt.py b/openfl/federated/task/runner_pt.py
index a95e942962..beb0fa4b3d 100644
--- a/openfl/federated/task/runner_pt.py
+++ b/openfl/federated/task/runner_pt.py
@@ -8,7 +8,7 @@
from typing import Tuple
import numpy as np
-import torch as pt
+import torch
import torch.nn as nn
import tqdm
@@ -22,13 +22,7 @@
class PyTorchTaskRunner(nn.Module, TaskRunner):
"""PyTorch Model class for Federated Learning."""
- def __init__(
- self,
- device: str = None,
- loss_fn=None,
- optimizer=None,
- **kwargs
- ):
+ def __init__(self, device: str = None, loss_fn=None, optimizer=None, **kwargs):
"""Initialize.
Args:
@@ -40,7 +34,7 @@ def __init__(
if device:
self.device = device
else:
- self.device = pt.device('cuda' if pt.cuda.is_available() else 'cpu')
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# This is a map of all the required tensors for each of the public
# functions in PyTorchTaskRunner
@@ -53,9 +47,9 @@ def __init__(
# overwrite attribute to account for one optimizer param (in every
# child model that does not overwrite get and set tensordict) that is
# not a numpy array
- self.tensor_dict_split_fn_kwargs.update({
- 'holdout_tensor_names': ['__opt_state_needed']
- })
+ self.tensor_dict_split_fn_kwargs.update(
+ {"holdout_tensor_names": ["__opt_state_needed"]}
+ )
def rebuild_model(self, round_num, input_tensor_dict, validation=False):
"""
@@ -64,18 +58,22 @@ def rebuild_model(self, round_num, input_tensor_dict, validation=False):
Returns:
None
"""
- if self.opt_treatment == 'RESET':
+ if self.opt_treatment == "RESET":
self.reset_opt_vars()
self.set_tensor_dict(input_tensor_dict, with_opt_vars=False)
- elif (self.training_round_completed
- and self.opt_treatment == 'CONTINUE_GLOBAL' and not validation):
+ elif (
+ self.training_round_completed
+ and self.opt_treatment == "CONTINUE_GLOBAL"
+ and not validation
+ ):
self.set_tensor_dict(input_tensor_dict, with_opt_vars=True)
else:
self.set_tensor_dict(input_tensor_dict, with_opt_vars=False)
- def validate(self, col_name, round_num, input_tensor_dict,
- use_tqdm=False, **kwargs):
- """Validate.
+ def validate_task(
+ self, col_name, round_num, input_tensor_dict, use_tqdm=False, **kwargs
+ ):
+ """Validate Task.
Run validation of the model on the local data.
@@ -93,46 +91,34 @@ def validate(self, col_name, round_num, input_tensor_dict,
self.rebuild_model(round_num, input_tensor_dict, validation=True)
self.eval()
self.to(self.device)
- val_score = 0
- total_samples = 0
loader = self.data_loader.get_valid_loader()
if use_tqdm:
- loader = tqdm.tqdm(loader, desc='validate')
+ loader = tqdm.tqdm(loader, desc="validate")
- with pt.no_grad():
- for data, target in loader:
- samples = target.shape[0]
- total_samples += samples
- data, target = pt.tensor(data).to(self.device), pt.tensor(
- target).to(self.device, dtype=pt.int64)
- output = self(data)
- # get the index of the max log-probability
- pred = output.argmax(dim=1, keepdim=True)
- target_categorical = target.argmax(dim=1, keepdim=True)
- val_score += pred.eq(target_categorical).sum().cpu().numpy()
+ metric = self.validate_(loader)
origin = col_name
- suffix = 'validate'
- if kwargs['apply'] == 'local':
- suffix += '_local'
+ suffix = "validate"
+ if kwargs["apply"] == "local":
+ suffix += "_local"
else:
- suffix += '_agg'
- tags = ('metric',)
+ suffix += "_agg"
+ tags = ("metric",)
tags = change_tags(tags, add_field=suffix)
# TODO figure out a better way to pass in metric for this pytorch
# validate function
output_tensor_dict = {
- TensorKey('acc', origin, round_num, True, tags):
- np.array(val_score / total_samples)
+ TensorKey(metric.name, origin, round_num, True, tags): metric.value
}
# Empty list represents metrics that should only be stored locally
return output_tensor_dict, {}
- def train_batches(self, col_name, round_num, input_tensor_dict,
- use_tqdm=False, epochs=1, **kwargs):
- """Train batches.
+ def train_task(
+ self, col_name, round_num, input_tensor_dict, use_tqdm=False, epochs=1, **kwargs
+ ):
+ """Train batches task.
Train the model on the requested number of batches.
@@ -152,51 +138,46 @@ def train_batches(self, col_name, round_num, input_tensor_dict,
self.train()
self.to(self.device)
for epoch in range(epochs):
- self.logger.info(f'Run {epoch} epoch of {round_num} round')
+ self.logger.info(f"Run {epoch} epoch of {round_num} round")
loader = self.data_loader.get_train_loader()
if use_tqdm:
- loader = tqdm.tqdm(loader, desc='train epoch')
- metric = self.train_epoch(loader)
+ loader = tqdm.tqdm(loader, desc="train epoch")
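+ # Delegate the per-epoch training loop to the overridable train_() hook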
+ metric = self.train_(loader)
# Output metric tensors (scalar)
origin = col_name
- tags = ('trained',)
+ tags = ("trained",)
output_metric_dict = {
- TensorKey(
- metric.name, origin, round_num, True, ('metric',)
- ): metric.value
+ TensorKey(metric.name, origin, round_num, True, ("metric",)): metric.value
}
# output model tensors (Doesn't include TensorKey)
output_model_dict = self.get_tensor_dict(with_opt_vars=True)
global_model_dict, local_model_dict = split_tensor_dict_for_holdouts(
- self.logger, output_model_dict,
- **self.tensor_dict_split_fn_kwargs
+ self.logger, output_model_dict, **self.tensor_dict_split_fn_kwargs
)
# Create global tensorkeys
global_tensorkey_model_dict = {
- TensorKey(tensor_name, origin, round_num, False, tags):
- nparray for tensor_name, nparray in global_model_dict.items()
+ TensorKey(tensor_name, origin, round_num, False, tags): nparray
+ for tensor_name, nparray in global_model_dict.items()
}
# Create tensorkeys that should stay local
local_tensorkey_model_dict = {
- TensorKey(tensor_name, origin, round_num, False, tags):
- nparray for tensor_name, nparray in local_model_dict.items()
+ TensorKey(tensor_name, origin, round_num, False, tags): nparray
+ for tensor_name, nparray in local_model_dict.items()
}
# The train/validate aggregated function of the next round will look
# for the updated model parameters.
# This ensures they will be resolved locally
next_local_tensorkey_model_dict = {
- TensorKey(tensor_name, origin, round_num + 1, False, ('model',)): nparray
- for tensor_name, nparray in local_model_dict.items()}
-
- global_tensor_dict = {
- **output_metric_dict,
- **global_tensorkey_model_dict
+ TensorKey(tensor_name, origin, round_num + 1, False, ("model",)): nparray
+ for tensor_name, nparray in local_model_dict.items()
}
+
+ global_tensor_dict = {**output_metric_dict, **global_tensorkey_model_dict}
local_tensor_dict = {
**local_tensorkey_model_dict,
- **next_local_tensorkey_model_dict
+ **next_local_tensorkey_model_dict,
}
# Update the required tensors if they need to be pulled from the
@@ -209,7 +190,7 @@ def train_batches(self, col_name, round_num, input_tensor_dict,
# these are only created after training occurs. A work around could
# involve doing a single epoch of training on random data to get the
# optimizer names, and then throwing away the model.
- if self.opt_treatment == 'CONTINUE_GLOBAL':
+ if self.opt_treatment == "CONTINUE_GLOBAL":
self.initialize_tensorkeys_for_functions(with_opt_vars=True)
# This will signal that the optimizer values are now present,
@@ -281,14 +262,14 @@ def set_tensor_dict(self, tensor_dict, with_opt_vars=False):
# Grabbing keys from model's state_dict helps to confirm we have
# everything
for k in self.state_dict():
- new_state[k] = pt.from_numpy(tensor_dict.pop(k)).to(device)
+ new_state[k] = torch.from_numpy(tensor_dict.pop(k)).to(device)
# set model state
self.load_state_dict(new_state)
if with_opt_vars:
# see if there is state to restore first
- if tensor_dict.pop('__opt_state_needed') == 'true':
+ if tensor_dict.pop("__opt_state_needed") == "true":
_set_optimizer_state(self.get_optimizer(), device, tensor_dict)
# sanity check that we did not record any state that was not used
@@ -310,8 +291,8 @@ def get_required_tensorkeys_for_function(self, func_name, **kwargs):
Returns:
list : [TensorKey]
"""
- if func_name == 'validate':
- local_model = 'apply=' + str(kwargs['apply'])
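+ # validate_task entries have an extra lookup level keyed by "apply=local" / "apply=global"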
+ if func_name == "validate_task":
+ local_model = "apply=" + str(kwargs["apply"])
return self.required_tensorkeys_for_function[func_name][local_model]
else:
return self.required_tensorkeys_for_function[func_name]
@@ -334,59 +315,61 @@ def initialize_tensorkeys_for_functions(self, with_opt_vars=False):
output_model_dict = self.get_tensor_dict(with_opt_vars=with_opt_vars)
global_model_dict, local_model_dict = split_tensor_dict_for_holdouts(
- self.logger, output_model_dict,
- **self.tensor_dict_split_fn_kwargs
+ self.logger, output_model_dict, **self.tensor_dict_split_fn_kwargs
)
if not with_opt_vars:
global_model_dict_val = global_model_dict
local_model_dict_val = local_model_dict
else:
output_model_dict = self.get_tensor_dict(with_opt_vars=False)
- global_model_dict_val, local_model_dict_val = split_tensor_dict_for_holdouts(
- self.logger,
- output_model_dict,
- **self.tensor_dict_split_fn_kwargs
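+ # Validation only needs the model weights, so redo the split without optimizer state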
+ global_model_dict_val, local_model_dict_val = (
+ split_tensor_dict_for_holdouts(
+ self.logger, output_model_dict, **self.tensor_dict_split_fn_kwargs
+ )
)
- self.required_tensorkeys_for_function['train_batches'] = [
- TensorKey(tensor_name, 'GLOBAL', 0, False, ('model',))
- for tensor_name in global_model_dict]
- self.required_tensorkeys_for_function['train_batches'] += [
- TensorKey(tensor_name, 'LOCAL', 0, False, ('model',))
- for tensor_name in local_model_dict]
-
- self.required_tensorkeys_for_function['train'] = [
- TensorKey(
- tensor_name, 'GLOBAL', 0, False, ('model',)
- ) for tensor_name in global_model_dict
+ self.required_tensorkeys_for_function["train_task"] = [
+ TensorKey(tensor_name, "GLOBAL", 0, False, ("model",))
+ for tensor_name in global_model_dict
+ ]
+ self.required_tensorkeys_for_function["train_task"] += [
+ TensorKey(tensor_name, "LOCAL", 0, False, ("model",))
+ for tensor_name in local_model_dict
+ ]
+
- ]
- self.required_tensorkeys_for_function['train'] += [
- TensorKey(
- tensor_name, 'LOCAL', 0, False, ('model',)
- ) for tensor_name in local_model_dict
- ]
# Validation may be performed on local or aggregated (global) model,
# so there is an extra lookup dimension for kwargs
- self.required_tensorkeys_for_function['validate'] = {}
+ self.required_tensorkeys_for_function["validate_task"] = {}
# TODO This is not stateless. The optimizer will not be
- self.required_tensorkeys_for_function['validate']['apply=local'] = [
- TensorKey(tensor_name, 'LOCAL', 0, False, ('trained',))
- for tensor_name in {
- **global_model_dict_val,
- **local_model_dict_val
- }]
- self.required_tensorkeys_for_function['validate']['apply=global'] = [
- TensorKey(tensor_name, 'GLOBAL', 0, False, ('model',))
+ self.required_tensorkeys_for_function["validate_task"]["apply=local"] = [
+ TensorKey(tensor_name, "LOCAL", 0, False, ("trained",))
+ for tensor_name in {**global_model_dict_val, **local_model_dict_val}
+ ]
+ self.required_tensorkeys_for_function["validate_task"]["apply=global"] = [
+ TensorKey(tensor_name, "GLOBAL", 0, False, ("model",))
for tensor_name in global_model_dict_val
]
- self.required_tensorkeys_for_function['validate']['apply=global'] += [
- TensorKey(tensor_name, 'LOCAL', 0, False, ('model',))
+ self.required_tensorkeys_for_function["validate_task"]["apply=global"] += [
+ TensorKey(tensor_name, "LOCAL", 0, False, ("model",))
for tensor_name in local_model_dict_val
]
- def load_native(self, filepath, model_state_dict_key='model_state_dict',
- optimizer_state_dict_key='optimizer_state_dict', **kwargs):
+ def load_native(
+ self,
+ filepath,
+ model_state_dict_key="model_state_dict",
+ optimizer_state_dict_key="optimizer_state_dict",
+ **kwargs,
+ ):
"""
Load model and optimizer states from a pickled file specified by \
filepath. model_/optimizer_state_dict args can be specified if needed. \
@@ -394,7 +377,7 @@ def load_native(self, filepath, model_state_dict_key='model_state_dict',
Args:
filepath (string) : Path to pickle file created
- by pt.save().
+ by torch.save().
model_state_dict_key (string) : key for model state dict
in pickled file.
optimizer_state_dict_key (string) : key for optimizer state dict
@@ -404,20 +387,25 @@ def load_native(self, filepath, model_state_dict_key='model_state_dict',
Returns:
None
"""
- pickle_dict = pt.load(filepath)
+ pickle_dict = torch.load(filepath)
self.load_state_dict(pickle_dict[model_state_dict_key])
self.optimizer.load_state_dict(pickle_dict[optimizer_state_dict_key])
- def save_native(self, filepath, model_state_dict_key='model_state_dict',
- optimizer_state_dict_key='optimizer_state_dict', **kwargs):
+ def save_native(
+ self,
+ filepath,
+ model_state_dict_key="model_state_dict",
+ optimizer_state_dict_key="optimizer_state_dict",
+ **kwargs,
+ ):
"""
Save model and optimizer states in a pickled file specified by the \
filepath. model_/optimizer_state_dicts are stored in the keys provided. \
- Uses pt.save().
+ Uses torch.save().
Args:
filepath (string) : Path to pickle file to be
- created by pt.save().
+ created by torch.save().
model_state_dict_key (string) : key for model state dict
in pickled file.
optimizer_state_dict_key (string) : key for optimizer state
@@ -429,20 +417,21 @@ def save_native(self, filepath, model_state_dict_key='model_state_dict',
"""
pickle_dict = {
model_state_dict_key: self.state_dict(),
- optimizer_state_dict_key: self.optimizer.state_dict()
+ optimizer_state_dict_key: self.optimizer.state_dict(),
}
- pt.save(pickle_dict, filepath)
+ torch.save(pickle_dict, filepath)
def reset_opt_vars(self):
- """
- Reset optimizer variables.
+ """Reset optimizer variables.
- Resets the optimizer variables
+ Resets the optimizer state variables.
"""
pass
- def train_epoch(self, batch_generator: Iterator[Tuple[np.ndarray, np.ndarray]]) -> Metric:
+ def train_(
+ self, train_dataloader: Iterator[Tuple[np.ndarray, np.ndarray]]
+ ) -> Metric:
"""Train single epoch.
Override this function in order to use custom training.
@@ -454,9 +443,10 @@ def train_epoch(self, batch_generator: Iterator[Tuple[np.ndarray, np.ndarray]])
Metric: An object containing name and np.ndarray value.
"""
losses = []
- for data, target in batch_generator:
- data, target = pt.tensor(data).to(self.device), pt.tensor(
- target).to(self.device)
+ for data, target in train_dataloader:
+ data, target = torch.tensor(data).to(self.device), torch.tensor(target).to(
+ self.device
+ )
self.optimizer.zero_grad()
output = self(data)
loss = self.loss_fn(output=output, target=target)
@@ -466,6 +456,38 @@ def train_epoch(self, batch_generator: Iterator[Tuple[np.ndarray, np.ndarray]])
loss = np.mean(losses)
return Metric(name=self.loss_fn.__name__, value=np.array(loss))
+ def validate_(
+ self, validation_dataloader: Iterator[Tuple[np.ndarray, np.ndarray]]
+ ) -> Metric:
+ """
+ Perform validation on PyTorch Model
+
+ Override this function for your own custom validation function
+
+ Args:
+ validation_data_loader: Validation dataset batch generator.
+ Yields (samples, targets) tuples.
+ Returns:
+ Metric: An object containing name and np.ndarray value
+ """
+
+ total_samples = 0
+ val_score = 0
+ with torch.no_grad():
+ for data, target in validation_dataloader:
+ samples = target.shape[0]
+ total_samples += samples
+ data, target = torch.tensor(data).to(self.device), torch.tensor(
+ target
+ ).to(self.device, dtype=torch.int64)
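+ # targets are expected to be integer class labels (hence the int64 cast above)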
+ output = self(data)
+ # get the index of the max log-probability
+ pred = output.argmax(dim=1)
+ val_score += pred.eq(target).sum().cpu().numpy()
+
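+ # Unweighted accuracy: fraction of correctly classified validation samples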
+ accuracy = val_score / total_samples
+ return Metric(name="accuracy", value=np.array(accuracy))
+
def _derive_opt_state_dict(opt_state_dict):
"""Separate optimizer tensors from the tensor dictionary.
@@ -482,18 +504,16 @@ def _derive_opt_state_dict(opt_state_dict):
derived_opt_state_dict = {}
# Determine if state is needed for this optimizer.
- if len(opt_state_dict['state']) == 0:
- derived_opt_state_dict['__opt_state_needed'] = 'false'
+ if len(opt_state_dict["state"]) == 0:
+ derived_opt_state_dict["__opt_state_needed"] = "false"
return derived_opt_state_dict
- derived_opt_state_dict['__opt_state_needed'] = 'true'
+ derived_opt_state_dict["__opt_state_needed"] = "true"
# Using one example state key, we collect keys for the corresponding
# dictionary value.
- example_state_key = opt_state_dict['param_groups'][0]['params'][0]
- example_state_subkeys = set(
- opt_state_dict['state'][example_state_key].keys()
- )
+ example_state_key = opt_state_dict["param_groups"][0]["params"][0]
+ example_state_subkeys = set(opt_state_dict["state"][example_state_key].keys())
# We assume that the state collected for all params in all param groups is
# the same.
@@ -501,52 +521,47 @@ def _derive_opt_state_dict(opt_state_dict):
# subkeys is a tensor depends only on the subkey.
# Using assert statements to break the routine if these assumptions are
# incorrect.
- for state_key in opt_state_dict['state'].keys():
- assert example_state_subkeys == set(opt_state_dict['state'][state_key].keys())
+ for state_key in opt_state_dict["state"].keys():
+ assert example_state_subkeys == set(opt_state_dict["state"][state_key].keys())
for state_subkey in example_state_subkeys:
- assert (isinstance(
- opt_state_dict['state'][example_state_key][state_subkey],
- pt.Tensor)
- == isinstance(
- opt_state_dict['state'][state_key][state_subkey],
- pt.Tensor))
+ assert isinstance(
+ opt_state_dict["state"][example_state_key][state_subkey], torch.Tensor
+ ) == isinstance(
+ opt_state_dict["state"][state_key][state_subkey], torch.Tensor
+ )
- state_subkeys = list(opt_state_dict['state'][example_state_key].keys())
+ state_subkeys = list(opt_state_dict["state"][example_state_key].keys())
# Tags will record whether the value associated to the subkey is a
# tensor or not.
state_subkey_tags = []
for state_subkey in state_subkeys:
if isinstance(
- opt_state_dict['state'][example_state_key][state_subkey],
- pt.Tensor
+ opt_state_dict["state"][example_state_key][state_subkey], torch.Tensor
):
- state_subkey_tags.append('istensor')
+ state_subkey_tags.append("istensor")
else:
- state_subkey_tags.append('')
+ state_subkey_tags.append("")
state_subkeys_and_tags = list(zip(state_subkeys, state_subkey_tags))
# Forming the flattened dict, using a concatenation of group index,
# subindex, tag, and subkey inserted into the flattened dict key -
# needed for reconstruction.
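+ # Example key: "__opt_state_0_3_istensor_exp_avg" (group 0, param 3, Adam's exp_avg tensor)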
nb_params_per_group = []
- for group_idx, group in enumerate(opt_state_dict['param_groups']):
- for idx, param_id in enumerate(group['params']):
+ for group_idx, group in enumerate(opt_state_dict["param_groups"]):
+ for idx, param_id in enumerate(group["params"]):
for subkey, tag in state_subkeys_and_tags:
- if tag == 'istensor':
- new_v = opt_state_dict['state'][param_id][
- subkey].cpu().numpy()
+ if tag == "istensor":
+ new_v = opt_state_dict["state"][param_id][subkey].cpu().numpy()
else:
- new_v = np.array(
- [opt_state_dict['state'][param_id][subkey]]
- )
- derived_opt_state_dict[f'__opt_state_{group_idx}_{idx}_{tag}_{subkey}'] = new_v
+ new_v = np.array([opt_state_dict["state"][param_id][subkey]])
+ derived_opt_state_dict[
+ f"__opt_state_{group_idx}_{idx}_{tag}_{subkey}"
+ ] = new_v
nb_params_per_group.append(idx + 1)
# group lengths are also helpful for reconstructing
# original opt_state_dict structure
- derived_opt_state_dict['__opt_group_lengths'] = np.array(
- nb_params_per_group
- )
+ derived_opt_state_dict["__opt_group_lengths"] = np.array(nb_params_per_group)
return derived_opt_state_dict
@@ -568,38 +583,38 @@ def expand_derived_opt_state_dict(derived_opt_state_dict, device):
"""
state_subkeys_and_tags = []
for key in derived_opt_state_dict:
- if key.startswith('__opt_state_0_0_'):
+ if key.startswith("__opt_state_0_0_"):
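+ # Strip the 16-character "__opt_state_0_0_" prefix; the remainder is "istensor_<subkey>" or "_<subkey>"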
stripped_key = key[16:]
- if stripped_key.startswith('istensor_'):
- this_tag = 'istensor'
+ if stripped_key.startswith("istensor_"):
+ this_tag = "istensor"
subkey = stripped_key[9:]
else:
- this_tag = ''
+ this_tag = ""
subkey = stripped_key[1:]
state_subkeys_and_tags.append((subkey, this_tag))
- opt_state_dict = {'param_groups': [], 'state': {}}
+ opt_state_dict = {"param_groups": [], "state": {}}
nb_params_per_group = list(
- derived_opt_state_dict.pop('__opt_group_lengths').astype(np.int32)
+ derived_opt_state_dict.pop("__opt_group_lengths").astype(np.int32)
)
# Construct the expanded dict.
for group_idx, nb_params in enumerate(nb_params_per_group):
- these_group_ids = [f'{group_idx}_{idx}' for idx in range(nb_params)]
- opt_state_dict['param_groups'].append({'params': these_group_ids})
+ these_group_ids = [f"{group_idx}_{idx}" for idx in range(nb_params)]
+ opt_state_dict["param_groups"].append({"params": these_group_ids})
for this_id in these_group_ids:
- opt_state_dict['state'][this_id] = {}
+ opt_state_dict["state"][this_id] = {}
for subkey, tag in state_subkeys_and_tags:
- flat_key = f'__opt_state_{this_id}_{tag}_{subkey}'
- if tag == 'istensor':
- new_v = pt.from_numpy(derived_opt_state_dict.pop(flat_key))
+ flat_key = f"__opt_state_{this_id}_{tag}_{subkey}"
+ if tag == "istensor":
+ new_v = torch.from_numpy(derived_opt_state_dict.pop(flat_key))
else:
# Here (for currently supported optimizers) the subkey
# should be 'step' and the length of array should be one.
- assert subkey == 'step'
+ assert subkey == "step"
assert len(derived_opt_state_dict[flat_key]) == 1
new_v = int(derived_opt_state_dict.pop(flat_key))
- opt_state_dict['state'][this_id][subkey] = new_v
+ opt_state_dict["state"][this_id][subkey] = new_v
# sanity check that we did not miss any optimizer state
assert len(derived_opt_state_dict) == 0
@@ -617,11 +632,11 @@ def _get_optimizer_state(optimizer):
# Optimizer state might not have some parts representing frozen parameters
# So we do not synchronize them
- param_keys_with_state = set(opt_state_dict['state'].keys())
- for group in opt_state_dict['param_groups']:
- local_param_set = set(group['params'])
+ param_keys_with_state = set(opt_state_dict["state"].keys())
+ for group in opt_state_dict["param_groups"]:
+ local_param_set = set(group["params"])
params_to_sync = local_param_set & param_keys_with_state
- group['params'] = sorted(params_to_sync)
+ group["params"] = sorted(params_to_sync)
derived_opt_state_dict = _derive_opt_state_dict(opt_state_dict)
@@ -637,15 +652,14 @@ def _set_optimizer_state(optimizer, device, derived_opt_state_dict):
derived_opt_state_dict:
"""
- temp_state_dict = expand_derived_opt_state_dict(
- derived_opt_state_dict, device)
+ temp_state_dict = expand_derived_opt_state_dict(derived_opt_state_dict, device)
# FIXME: Figure out whether or not this breaks learning rate
# scheduling and the like.
# Setting default values.
# All optimizer.defaults are considered as not changing over course of
# training.
- for group in temp_state_dict['param_groups']:
+ for group in temp_state_dict["param_groups"]:
for k, v in optimizer.defaults.items():
group[k] = v
@@ -664,9 +678,11 @@ def to_cpu_numpy(state):
for k, v in state.items():
# When restoring, we currently assume all values are tensors.
- if not pt.is_tensor(v):
- raise ValueError('We do not currently support non-tensors '
- 'coming from model.state_dict()')
+ if not torch.is_tensor(v):
+ raise ValueError(
+ "We do not currently support non-tensors "
+ "coming from model.state_dict()"
+ )
# get as a numpy array, making sure is on cpu
state[k] = v.cpu().numpy()
return state