diff --git a/openfl-tutorials/Federated_FedProx_Keras_MNIST_Tutorial.ipynb b/openfl-tutorials/Federated_FedProx_Keras_MNIST_Tutorial.ipynb
index cc0dcc1a9c..c4ca0fda31 100644
--- a/openfl-tutorials/Federated_FedProx_Keras_MNIST_Tutorial.ipynb
+++ b/openfl-tutorials/Federated_FedProx_Keras_MNIST_Tutorial.ipynb
@@ -10,9 +10,7 @@
  {
   "cell_type": "code",
   "execution_count": null,
-  "metadata": {
-   "scrolled": true
-  },
+  "metadata": {},
   "outputs": [],
   "source": [
    "#Install Tensorflow and MNIST dataset if not installed\n",
@@ -293,13 +291,12 @@
  {
   "cell_type": "code",
   "execution_count": null,
-  "metadata": {
-   "scrolled": true
-  },
+  "metadata": {},
   "outputs": [],
   "source": [
    "#Run experiment, return trained FederatedModel\n",
-   "final_fl_model = fx.run_experiment(collaborators,override_config={'aggregator.settings.rounds_to_train':5, 'collaborator.settings.opt_treatment': 'CONTINUE_GLOBAL'})"
+   "from openfl.utilities.enum_types import OptTreatment\n",
+   "final_fl_model = fx.run_experiment(collaborators,override_config={'aggregator.settings.rounds_to_train':5, 'collaborator.settings.opt_treatment': OptTreatment.CONTINUE_GLOBAL})"
   ]
  },
 {
@@ -354,7 +351,7 @@
 ],
 "metadata": {
  "kernelspec": {
-  "display_name": "Python 3",
+  "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
@@ -367,8 +364,7 @@
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
-  "pygments_lexer": "ipython3",
-  "version": "3.8.8"
+  "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
diff --git a/openfl-tutorials/interactive_api/MXNet_landmarks/workspace/MXNet_landmarks.ipynb b/openfl-tutorials/interactive_api/MXNet_landmarks/workspace/MXNet_landmarks.ipynb
index 023990209e..ab1dc80e93 100644
--- a/openfl-tutorials/interactive_api/MXNet_landmarks/workspace/MXNet_landmarks.ipynb
+++ b/openfl-tutorials/interactive_api/MXNet_landmarks/workspace/MXNet_landmarks.ipynb
@@ -429,13 +429,16 @@
   "outputs": [],
   "source": [
    "# The following command zips the workspace and python requirements to be transfered to collaborator nodes\n",
+   "from openfl.utilities.enum_types import DevicePolicy\n",
+   "from openfl.utilities.enum_types import OptTreatment\n",
+   "\n",
    "fl_experiment.start(\n",
    "    model_provider=MI,\n",
    "    task_keeper=TI,\n",
    "    data_loader=fed_dataset,\n",
    "    rounds_to_train=10,\n",
-   "    opt_treatment=\"CONTINUE_GLOBAL\",\n",
-   "    device_assignment_policy=\"CUDA_PREFERRED\",\n",
+   "    opt_treatment=OptTreatment.CONTINUE_GLOBAL,\n",
+   "    device_assignment_policy=DevicePolicy.CUDA_PREFERRED,\n",
    ")"
   ]
  },
diff --git a/openfl-tutorials/interactive_api/PyTorch_DogsCats_ViT/workspace/PyTorch_DogsCats_ViT.ipynb b/openfl-tutorials/interactive_api/PyTorch_DogsCats_ViT/workspace/PyTorch_DogsCats_ViT.ipynb
index e5e6414b9b..8b66dc2b53 100644
--- a/openfl-tutorials/interactive_api/PyTorch_DogsCats_ViT/workspace/PyTorch_DogsCats_ViT.ipynb
+++ b/openfl-tutorials/interactive_api/PyTorch_DogsCats_ViT/workspace/PyTorch_DogsCats_ViT.ipynb
@@ -510,12 +510,15 @@
   "outputs": [],
   "source": [
    "# The following command zips the workspace and python requirements to be transfered to collaborator nodes\n",
+   "from openfl.utilities.enum_types import DevicePolicy\n",
+   "from openfl.utilities.enum_types import OptTreatment\n",
+   "\n",
    "fl_experiment.start(model_provider=MI,\n",
    "                    task_keeper=TI,\n",
    "                    data_loader=fed_dataset,\n",
    "                    rounds_to_train=5,\n",
-   "                    opt_treatment='CONTINUE_GLOBAL',\n",
-   "                    device_assignment_policy='CUDA_PREFERRED')"
+   "                    opt_treatment=OptTreatment.CONTINUE_GLOBAL,\n",
+   "                    device_assignment_policy=DevicePolicy.CUDA_PREFERRED)"
   ]
  },
 {
diff --git a/openfl-tutorials/interactive_api/PyTorch_Histology/workspace/pytorch_histology.ipynb b/openfl-tutorials/interactive_api/PyTorch_Histology/workspace/pytorch_histology.ipynb
index 02888ecf9b..32fc3feaa9 100644
--- a/openfl-tutorials/interactive_api/PyTorch_Histology/workspace/pytorch_histology.ipynb
+++ b/openfl-tutorials/interactive_api/PyTorch_Histology/workspace/pytorch_histology.ipynb
@@ -463,12 +463,16 @@
   "outputs": [],
   "source": [
    "# The following command zips the workspace and python requirements to be transfered to collaborator nodes\n",
+   "from openfl.utilities.enum_types import DevicePolicy\n",
+   "from openfl.utilities.enum_types import OptTreatment\n",
+   "\n",
    "fl_experiment.start(\n",
    "    model_provider=model_interface, \n",
    "    task_keeper=task_interface,\n",
    "    data_loader=fed_dataset,\n",
    "    rounds_to_train=5,\n",
-   "    opt_treatment='CONTINUE_GLOBAL'\n",
+   "    opt_treatment=OptTreatment.CONTINUE_GLOBAL,\n",
+   "    device_assignment_policy=DevicePolicy.CUDA_PREFERRED,\n",
    ")"
   ]
  },
@@ -487,8 +491,22 @@
  }
 ],
 "metadata": {
+ "kernelspec": {
+  "display_name": "Python 3 (ipykernel)",
+  "language": "python",
+  "name": "python3"
+ },
  "language_info": {
-  "name": "python"
+  "codemirror_mode": {
+   "name": "ipython",
+   "version": 3
+  },
+  "file_extension": ".py",
+  "mimetype": "text/x-python",
+  "name": "python",
+  "nbconvert_exporter": "python",
+  "pygments_lexer": "ipython3",
+  "version": "3.8.9"
  }
 },
 "nbformat": 4,
diff --git a/openfl-tutorials/interactive_api/PyTorch_Histology_FedCurv/workspace/pytorch_histology.ipynb b/openfl-tutorials/interactive_api/PyTorch_Histology_FedCurv/workspace/pytorch_histology.ipynb
index f062d97fdd..2bff54f29b 100644
--- a/openfl-tutorials/interactive_api/PyTorch_Histology_FedCurv/workspace/pytorch_histology.ipynb
+++ b/openfl-tutorials/interactive_api/PyTorch_Histology_FedCurv/workspace/pytorch_histology.ipynb
@@ -474,12 +474,16 @@
   "outputs": [],
   "source": [
    "# The following command zips the workspace and python requirements to be transfered to collaborator nodes\n",
+   "from openfl.utilities.enum_types import DevicePolicy\n",
+   "from openfl.utilities.enum_types import OptTreatment\n",
+   "\n",
    "fl_experiment.start(\n",
    "    model_provider=model_interface, \n",
    "    task_keeper=task_interface,\n",
    "    data_loader=fed_dataset,\n",
    "    rounds_to_train=5,\n",
-   "    opt_treatment='CONTINUE_GLOBAL'\n",
+   "    opt_treatment=OptTreatment.CONTINUE_GLOBAL,\n",
+   "    device_assignment_policy=DevicePolicy.CUDA_PREFERRED\n",
    ")"
   ]
  },
diff --git a/openfl-tutorials/interactive_api/PyTorch_Huggingface_transformers_SUPERB/workspace/PyTorch_Huggingface_transformers_SUPERB.ipynb b/openfl-tutorials/interactive_api/PyTorch_Huggingface_transformers_SUPERB/workspace/PyTorch_Huggingface_transformers_SUPERB.ipynb
index 3ab32ef11b..faa28cd93f 100644
--- a/openfl-tutorials/interactive_api/PyTorch_Huggingface_transformers_SUPERB/workspace/PyTorch_Huggingface_transformers_SUPERB.ipynb
+++ b/openfl-tutorials/interactive_api/PyTorch_Huggingface_transformers_SUPERB/workspace/PyTorch_Huggingface_transformers_SUPERB.ipynb
@@ -438,13 +438,16 @@
   "metadata": {},
   "outputs": [],
   "source": [
+   "from openfl.utilities.enum_types import DevicePolicy\n",
+   "from openfl.utilities.enum_types import OptTreatment\n",
+   "\n",
    "fl_experiment.start(\n",
    "    model_provider=MI,\n",
    "    task_keeper=TI,\n",
    "    data_loader=fed_dataset,\n",
    "    rounds_to_train=2,\n",
-   "    opt_treatment=\"CONTINUE_GLOBAL\",\n",
-   "    device_assignment_policy=\"CUDA_PREFERRED\",\n",
+   "    opt_treatment=OptTreatment.CONTINUE_GLOBAL,\n",
+   "    device_assignment_policy=DevicePolicy.CUDA_PREFERRED,\n",
    ")"
   ]
  },
diff --git a/openfl-tutorials/interactive_api/PyTorch_Kvasir_UNet/workspace/PyTorch_Kvasir_UNet.ipynb b/openfl-tutorials/interactive_api/PyTorch_Kvasir_UNet/workspace/PyTorch_Kvasir_UNet.ipynb
index 91169b5e22..a4d7227ff5 100644
--- a/openfl-tutorials/interactive_api/PyTorch_Kvasir_UNet/workspace/PyTorch_Kvasir_UNet.ipynb
+++ b/openfl-tutorials/interactive_api/PyTorch_Kvasir_UNet/workspace/PyTorch_Kvasir_UNet.ipynb
@@ -484,12 +484,15 @@
    "# If I use autoreload I got a pickling error\n",
    "\n",
    "# The following command zips the workspace and python requirements to be transfered to collaborator nodes\n",
+   "from openfl.utilities.enum_types import DevicePolicy\n",
+   "from openfl.utilities.enum_types import OptTreatment\n",
+   "\n",
    "fl_experiment.start(model_provider=MI, \n",
    "                    task_keeper=TI,\n",
    "                    data_loader=fed_dataset,\n",
    "                    rounds_to_train=2,\n",
-   "                    opt_treatment='CONTINUE_GLOBAL',\n",
-   "                    device_assignment_policy='CUDA_PREFERRED')\n"
+   "                    opt_treatment=OptTreatment.CONTINUE_GLOBAL,\n",
+   "                    device_assignment_policy=DevicePolicy.CUDA_PREFERRED)\n"
   ]
  },
 {
@@ -584,7 +587,7 @@
   "source": [
    "MI = ModelInterface(model=best_model, optimizer=optimizer_adam, framework_plugin=framework_adapter)\n",
    "fl_experiment.start(model_provider=MI, task_keeper=TI, data_loader=fed_dataset, rounds_to_train=4, \\\n",
-   "                    opt_treatment='CONTINUE_GLOBAL')"
+   "                    opt_treatment=OptTreatment.CONTINUE_GLOBAL)"
   ]
  },
 {
diff --git a/openfl-tutorials/interactive_api/PyTorch_Lightning_MNIST_GAN/workspace/PyTorch_Lightning_GAN.ipynb b/openfl-tutorials/interactive_api/PyTorch_Lightning_MNIST_GAN/workspace/PyTorch_Lightning_GAN.ipynb
index 26cece8691..e0cdcc0c48 100644
--- a/openfl-tutorials/interactive_api/PyTorch_Lightning_MNIST_GAN/workspace/PyTorch_Lightning_GAN.ipynb
+++ b/openfl-tutorials/interactive_api/PyTorch_Lightning_MNIST_GAN/workspace/PyTorch_Lightning_GAN.ipynb
@@ -596,13 +596,16 @@
   "metadata": {},
   "outputs": [],
   "source": [
+   "from openfl.utilities.enum_types import DevicePolicy\n",
+   "from openfl.utilities.enum_types import OptTreatment\n",
+   "\n",
    "fl_experiment.start(\n",
    "    model_provider=MI,\n",
    "    task_keeper=TI,\n",
    "    data_loader=fed_dataset,\n",
    "    rounds_to_train=10,\n",
-   "    opt_treatment=\"CONTINUE_GLOBAL\",\n",
-   "    device_assignment_policy=\"CUDA_PREFERRED\",\n",
+   "    opt_treatment=OptTreatment.CONTINUE_GLOBAL,\n",
+   "    device_assignment_policy=DevicePolicy.CUDA_PREFERRED,\n",
    ")"
   ]
  },
diff --git a/openfl-tutorials/interactive_api/PyTorch_MVTec_PatchSVDD/workspace/PatchSVDD_with_Director.ipynb b/openfl-tutorials/interactive_api/PyTorch_MVTec_PatchSVDD/workspace/PatchSVDD_with_Director.ipynb
index 351934772a..a61ea3abdd 100644
--- a/openfl-tutorials/interactive_api/PyTorch_MVTec_PatchSVDD/workspace/PatchSVDD_with_Director.ipynb
+++ b/openfl-tutorials/interactive_api/PyTorch_MVTec_PatchSVDD/workspace/PatchSVDD_with_Director.ipynb
@@ -965,12 +965,15 @@
    "# If I use autoreload I got a pickling error\n",
    "\n",
    "# The following command zips the workspace and python requirements to be transfered to collaborator nodes\n",
+   "from openfl.utilities.enum_types import DevicePolicy\n",
+   "from openfl.utilities.enum_types import OptTreatment\n",
+   "\n",
    "fl_experiment.start(model_provider=MI, \n",
    "                    task_keeper=TI,\n",
    "                    data_loader=fed_dataset,\n",
    "                    rounds_to_train=10,\n",
-   "                    opt_treatment='CONTINUE_GLOBAL',\n",
-   "                    device_assignment_policy='CUDA_PREFERRED')\n"
+   "                    opt_treatment=OptTreatment.CONTINUE_GLOBAL,\n",
+   "                    device_assignment_policy=DevicePolicy.CUDA_PREFERRED)\n"
   ]
  },
diff --git a/openfl-tutorials/interactive_api/PyTorch_Market_Re-ID/workspace/PyTorch_Market_Re-ID.ipynb b/openfl-tutorials/interactive_api/PyTorch_Market_Re-ID/workspace/PyTorch_Market_Re-ID.ipynb
index 0a50d68895..4d19b0a755 100644
--- a/openfl-tutorials/interactive_api/PyTorch_Market_Re-ID/workspace/PyTorch_Market_Re-ID.ipynb
+++ b/openfl-tutorials/interactive_api/PyTorch_Market_Re-ID/workspace/PyTorch_Market_Re-ID.ipynb
@@ -543,11 +543,15 @@
    "# If I use autoreload I got a pickling error\n",
    "\n",
    "# The following command zips the workspace and python requirements to be transfered to collaborator nodes\n",
+   "from openfl.utilities.enum_types import DevicePolicy\n",
+   "from openfl.utilities.enum_types import OptTreatment\n",
+   "\n",
    "fl_experiment.start(model_provider=MI, \n",
    "                    task_keeper=TI,\n",
    "                    data_loader=fed_dataset,\n",
    "                    rounds_to_train=3,\n",
-   "                    opt_treatment='RESET')"
+   "                    opt_treatment=OptTreatment.RESET,\n",
+   "                    device_assignment_policy=DevicePolicy.CUDA_PREFERRED)"
   ]
  },
@@ -590,4 +594,4 @@
 },
 "nbformat": 4,
 "nbformat_minor": 5
-}
\ No newline at end of file
+}
diff --git a/openfl-tutorials/interactive_api/PyTorch_TinyImageNet/workspace/pytorch_tinyimagenet.ipynb b/openfl-tutorials/interactive_api/PyTorch_TinyImageNet/workspace/pytorch_tinyimagenet.ipynb
index bbfadfd086..b65c9db2cc 100644
--- a/openfl-tutorials/interactive_api/PyTorch_TinyImageNet/workspace/pytorch_tinyimagenet.ipynb
+++ b/openfl-tutorials/interactive_api/PyTorch_TinyImageNet/workspace/pytorch_tinyimagenet.ipynb
@@ -450,12 +450,16 @@
   "outputs": [],
   "source": [
    "# The following command zips the workspace and python requirements to be transfered to collaborator nodes\n",
+   "from openfl.utilities.enum_types import DevicePolicy\n",
+   "from openfl.utilities.enum_types import OptTreatment\n",
+   "\n",
    "fl_experiment.start(\n",
    "    model_provider=model_interface, \n",
    "    task_keeper=task_interface,\n",
    "    data_loader=fed_dataset,\n",
    "    rounds_to_train=5,\n",
-   "    opt_treatment='CONTINUE_GLOBAL'\n",
+   "    opt_treatment=OptTreatment.CONTINUE_GLOBAL,\n",
+   "    device_assignment_policy=DevicePolicy.CUDA_PREFERRED\n",
    ")"
   ]
  },
diff --git a/openfl-tutorials/interactive_api/Tensorflow_MNIST/workspace/Tensorflow_MNIST.ipynb b/openfl-tutorials/interactive_api/Tensorflow_MNIST/workspace/Tensorflow_MNIST.ipynb
index 5c818b2f34..bf496eaddc 100644
--- a/openfl-tutorials/interactive_api/Tensorflow_MNIST/workspace/Tensorflow_MNIST.ipynb
+++ b/openfl-tutorials/interactive_api/Tensorflow_MNIST/workspace/Tensorflow_MNIST.ipynb
@@ -360,11 +360,15 @@
   "outputs": [],
   "source": [
    "# The following command zips the workspace and python requirements to be transfered to collaborator nodes\n",
+   "from openfl.utilities.enum_types import DevicePolicy\n",
+   "from openfl.utilities.enum_types import OptTreatment\n",
+   "\n",
    "fl_experiment.start(model_provider=MI, \n",
    "                    task_keeper=TI,\n",
    "                    data_loader=fed_dataset,\n",
    "                    rounds_to_train=5,\n",
-   "                    opt_treatment='CONTINUE_GLOBAL')"
+   "                    opt_treatment=OptTreatment.CONTINUE_GLOBAL,\n",
+   "                    device_assignment_policy=DevicePolicy.CUDA_PREFERRED)"
   ]
  },
 {
diff --git a/openfl-tutorials/interactive_api/Tensorflow_Word_Prediction/workspace/Tensorflow_Word_Prediction.ipynb b/openfl-tutorials/interactive_api/Tensorflow_Word_Prediction/workspace/Tensorflow_Word_Prediction.ipynb
index 8ebaae13df..ccefb8a39a 100644
--- a/openfl-tutorials/interactive_api/Tensorflow_Word_Prediction/workspace/Tensorflow_Word_Prediction.ipynb
+++ b/openfl-tutorials/interactive_api/Tensorflow_Word_Prediction/workspace/Tensorflow_Word_Prediction.ipynb
@@ -376,11 +376,15 @@
    "# If I use autoreload I got a pickling error\n",
    "\n",
    "# The following command zips the workspace and python requirements to be transfered to collaborator nodes\n",
+   "from openfl.utilities.enum_types import DevicePolicy\n",
+   "from openfl.utilities.enum_types import OptTreatment\n",
+   "\n",
    "fl_experiment.start(model_provider=MI, \n",
    "                    task_keeper=TI,\n",
    "                    data_loader=fed_dataset,\n",
    "                    rounds_to_train=20,\n",
-   "                    opt_treatment='RESET')"
+   "                    opt_treatment=OptTreatment.RESET,\n",
+   "                    device_assignment_policy=DevicePolicy.CUDA_PREFERRED)"
   ]
  },
 {
diff --git a/openfl/component/collaborator/collaborator.py b/openfl/component/collaborator/collaborator.py
index 645e23fc9d..20dec77fae 100644
--- a/openfl/component/collaborator/collaborator.py
+++ b/openfl/component/collaborator/collaborator.py
@@ -3,7 +3,6 @@
 
 """Collaborator module."""
 
-from enum import Enum
 from logging import getLogger
 from time import sleep
 from typing import Tuple
@@ -12,35 +11,11 @@
 from openfl.pipelines import NoCompressionPipeline
 from openfl.pipelines import TensorCodec
 from openfl.protocols import utils
+from openfl.utilities.enum_types import DevicePolicy
+from openfl.utilities.enum_types import OptTreatment
 from openfl.utilities import TensorKey
 
 
-class DevicePolicy(Enum):
-    """Device assignment policy."""
-
-    CPU_ONLY = 1
-
-    CUDA_PREFERRED = 2
-
-
-class OptTreatment(Enum):
-    """Optimizer Methods.
-
-    - RESET tells each collaborator to reset the optimizer state at the beginning
-    of each round.
-
-    - CONTINUE_LOCAL tells each collaborator to continue with the local optimizer
-    state from the previous round.
-
-    - CONTINUE_GLOBAL tells each collaborator to continue with the federally
-    averaged optimizer state from the previous round.
-    """
-
-    RESET = 1
-    CONTINUE_LOCAL = 2
-    CONTINUE_GLOBAL = 3
-
-
 class Collaborator:
     r"""The Collaborator object class.
 
@@ -48,7 +24,7 @@ class Collaborator:
     Args:
         aggregator_uuid: The unique id for the client
        federation_uuid: The unique id for the federation
        model: The model
-        opt_treatment* (string): The optimizer state treatment (Defaults to
-            "CONTINUE_GLOBAL", which is aggreagated state from previous round.)
+        opt_treatment* (enum.Enum): The optimizer state treatment (Defaults to
+            OptTreatment.CONTINUE_GLOBAL, which is the aggregated state from the previous round.)
         compression_pipeline: The compression pipeline (Defaults to None)
@@ -74,8 +49,8 @@ def __init__(self,
                  client,
                  task_runner,
                  task_config,
-                 opt_treatment='RESET',
-                 device_assignment_policy='CPU_ONLY',
+                 opt_treatment=OptTreatment.RESET,
+                 device_assignment_policy=DevicePolicy.CPU_ONLY,
                  delta_updates=False,
                  compression_pipeline=None,
                  db_store_rounds=1,
@@ -105,23 +80,10 @@ def __init__(self,
 
         self.logger = getLogger(__name__)
 
-        # RESET/CONTINUE_LOCAL/CONTINUE_GLOBAL
-        if hasattr(OptTreatment, opt_treatment):
-            self.opt_treatment = OptTreatment[opt_treatment]
-        else:
-            self.logger.error(f'Unknown opt_treatment: {opt_treatment.name}.')
-            raise NotImplementedError(f'Unknown opt_treatment: {opt_treatment}.')
-
-        if hasattr(DevicePolicy, device_assignment_policy):
-            self.device_assignment_policy = DevicePolicy[device_assignment_policy]
-        else:
-            self.logger.error('Unknown device_assignment_policy: '
-                              f'{device_assignment_policy.name}.')
-            raise NotImplementedError(
-                f'Unknown device_assignment_policy: {device_assignment_policy}.'
-            )
+        self.opt_treatment = opt_treatment
+        self.device_assignment_policy = device_assignment_policy
 
-        self.task_runner.set_optimizer_treatment(self.opt_treatment.name)
+        self.task_runner.set_optimizer_treatment(self.opt_treatment)
 
     def set_available_devices(self, cuda: Tuple[str] = ()):
         """
diff --git a/openfl/federated/plan/plan.py b/openfl/federated/plan/plan.py
index 7ff0176417..61ffb492d5 100644
--- a/openfl/federated/plan/plan.py
+++ b/openfl/federated/plan/plan.py
@@ -18,6 +18,7 @@
 from openfl.interface.cli_helper import WORKSPACE
 from openfl.transport import AggregatorGRPCClient
 from openfl.transport import AggregatorGRPCServer
+from openfl.utilities.enum_types import OptTreatment
 from openfl.utilities.utils import getfqdn_env
 
 SETTINGS = 'settings'
@@ -458,6 +459,16 @@ def get_collaborator(self, collaborator_name, root_certificate=None, private_key
 
         defaults[SETTINGS]['compression_pipeline'] = self.get_tensor_pipe()
         defaults[SETTINGS]['task_config'] = self.config.get('tasks', {})
+
+        opt_treatment = defaults[SETTINGS]['opt_treatment']
+        if isinstance(opt_treatment, str) and hasattr(OptTreatment, opt_treatment):
+            # Plans store the enum member's name; convert it back to the member
+            # itself, since collaborators compare against OptTreatment members.
+            defaults[SETTINGS]['opt_treatment'] = OptTreatment[opt_treatment]
+        elif not isinstance(opt_treatment, OptTreatment):
+            self.logger.error(f'Unknown opt_treatment: {opt_treatment}.')
+            raise NotImplementedError(f'Unknown opt_treatment: {opt_treatment}.')
+
         if client is not None:
             defaults[SETTINGS]['client'] = client
         else:
diff --git a/openfl/federated/task/runner_fe.py b/openfl/federated/task/runner_fe.py
index acace8c61b..568c40ab79 100644
--- a/openfl/federated/task/runner_fe.py
+++ b/openfl/federated/task/runner_fe.py
@@ -4,6 +4,7 @@
 
 import numpy as np
 
+from openfl.utilities.enum_types import OptTreatment
 from openfl.utilities import split_tensor_dict_for_holdouts
 from openfl.utilities import TensorKey
 from .runner import TaskRunner
@@ -148,7 +149,7 @@ def train(self, col_name, round_num, input_tensor_dict, epochs, **kwargs):
         # A work around could involve doing a single epoch of training
         # on random data to get the optimizer names, and then throwing away
         # the model.
-        if self.opt_treatment == 'CONTINUE_GLOBAL':
+        if self.opt_treatment == OptTreatment.CONTINUE_GLOBAL:
             self.initialize_tensorkeys_for_functions(with_opt_vars=True)
 
         return global_tensor_dict, local_tensor_dict
diff --git a/openfl/federated/task/runner_keras.py b/openfl/federated/task/runner_keras.py
index 93ac40b1d0..389cf180fe 100644
--- a/openfl/federated/task/runner_keras.py
+++ b/openfl/federated/task/runner_keras.py
@@ -12,6 +12,7 @@
 import numpy as np
 
 from openfl.utilities import Metric
+from openfl.utilities.enum_types import OptTreatment
 from openfl.utilities import split_tensor_dict_for_holdouts
 from openfl.utilities import TensorKey
 from .runner import TaskRunner
@@ -51,10 +52,10 @@ def rebuild_model(self, round_num, input_tensor_dict, validation=False):
         -------
         None
         """
-        if self.opt_treatment == 'RESET':
+        if self.opt_treatment == OptTreatment.RESET:
             self.reset_opt_vars()
             self.set_tensor_dict(input_tensor_dict, with_opt_vars=False)
-        elif (round_num > 0 and self.opt_treatment == 'CONTINUE_GLOBAL'
+        elif (round_num > 0 and self.opt_treatment == OptTreatment.CONTINUE_GLOBAL
               and not validation):
             self.set_tensor_dict(input_tensor_dict, with_opt_vars=True)
         else:
@@ -139,7 +140,7 @@ def train(self, col_name, round_num, input_tensor_dict,
         # these are only created after training occurs. A work around could
         # involve doing a single epoch of training on random data to get the
         # optimizer names, and then throwing away the model.
-        if self.opt_treatment == 'CONTINUE_GLOBAL':
+        if self.opt_treatment == OptTreatment.CONTINUE_GLOBAL:
             self.initialize_tensorkeys_for_functions(with_opt_vars=True)
 
         return global_tensor_dict, local_tensor_dict
diff --git a/openfl/federated/task/runner_pt.py b/openfl/federated/task/runner_pt.py
index 2905f09798..5e15589a41 100644
--- a/openfl/federated/task/runner_pt.py
+++ b/openfl/federated/task/runner_pt.py
@@ -13,6 +13,7 @@
 import tqdm
 
 from openfl.utilities import Metric
+from openfl.utilities.enum_types import OptTreatment
 from openfl.utilities import split_tensor_dict_for_holdouts
 from openfl.utilities import TensorKey
 from .runner import TaskRunner
@@ -63,11 +64,11 @@ def rebuild_model(self, round_num, input_tensor_dict, validation=False):
         Returns:
             None
         """
-        if self.opt_treatment == 'RESET':
+        if self.opt_treatment == OptTreatment.RESET:
             self.reset_opt_vars()
             self.set_tensor_dict(input_tensor_dict, with_opt_vars=False)
         elif (self.training_round_completed
-              and self.opt_treatment == 'CONTINUE_GLOBAL' and not validation):
+              and self.opt_treatment == OptTreatment.CONTINUE_GLOBAL and not validation):
             self.set_tensor_dict(input_tensor_dict, with_opt_vars=True)
         else:
             self.set_tensor_dict(input_tensor_dict, with_opt_vars=False)
@@ -207,7 +208,7 @@ def train_batches(self, col_name, round_num, input_tensor_dict,
         # these are only created after training occurs. A work around could
         # involve doing a single epoch of training on random data to get the
         # optimizer names, and then throwing away the model.
-        if self.opt_treatment == 'CONTINUE_GLOBAL':
+        if self.opt_treatment == OptTreatment.CONTINUE_GLOBAL:
             self.initialize_tensorkeys_for_functions(with_opt_vars=True)
 
         # This will signal that the optimizer values are now present,
diff --git a/openfl/federated/task/runner_tf.py b/openfl/federated/task/runner_tf.py
index 620a0ca0dd..da52a85908 100644
--- a/openfl/federated/task/runner_tf.py
+++ b/openfl/federated/task/runner_tf.py
@@ -7,6 +7,7 @@
 import tensorflow.compat.v1 as tf
 from tqdm import tqdm
 
+from openfl.utilities.enum_types import OptTreatment
 from openfl.utilities import split_tensor_dict_for_holdouts
 from openfl.utilities import TensorKey
 from .runner import TaskRunner
@@ -74,10 +75,10 @@ def rebuild_model(self, round_num, input_tensor_dict, validation=False):
         Returns:
             None
         """
-        if self.opt_treatment == 'RESET':
+        if self.opt_treatment == OptTreatment.RESET:
             self.reset_opt_vars()
             self.set_tensor_dict(input_tensor_dict, with_opt_vars=False)
-        elif (round_num > 0 and self.opt_treatment == 'CONTINUE_GLOBAL'
+        elif (round_num > 0 and self.opt_treatment == OptTreatment.CONTINUE_GLOBAL
               and not validation):
             self.set_tensor_dict(input_tensor_dict, with_opt_vars=True)
         else:
@@ -172,7 +173,7 @@ def train_batches(self, col_name, round_num, input_tensor_dict,
         # these are only created after training occurs. A work around could
         # involve doing a single epoch of training on random data to get the
         # optimizer names, and then throwing away the model.
-        if self.opt_treatment == 'CONTINUE_GLOBAL':
+        if self.opt_treatment == OptTreatment.CONTINUE_GLOBAL:
             self.initialize_tensorkeys_for_functions(with_opt_vars=True)
 
         return global_tensor_dict, local_tensor_dict
diff --git a/openfl/federated/task/task_runner.py b/openfl/federated/task/task_runner.py
index 05023e4743..b486017398 100644
--- a/openfl/federated/task/task_runner.py
+++ b/openfl/federated/task/task_runner.py
@@ -6,6 +6,7 @@
 
 import numpy as np
 
+from openfl.utilities.enum_types import OptTreatment
 from openfl.utilities import split_tensor_dict_for_holdouts
 from openfl.utilities import TensorKey
 
@@ -63,7 +64,7 @@ def _prepare_tensorkeys_for_agggregation(self, metric_dict, validation_flag,
         # A work around could involve doing a single epoch of training
         # on random data to get the optimizer names,
         # and then throwing away the model.
-        if self.opt_treatment == 'CONTINUE_GLOBAL':
+        if self.opt_treatment == OptTreatment.CONTINUE_GLOBAL:
             self.initialize_tensorkeys_for_functions(with_opt_vars=True)
 
         # This will signal that the optimizer values are now present,
@@ -154,7 +155,7 @@ def __init__(self, **kwargs):
         self.TASK_REGISTRY = {}
 
         # Why is it here
-        self.opt_treatment = 'RESET'
+        self.opt_treatment = OptTreatment.RESET
         self.tensor_dict_split_fn_kwargs = {}
         self.required_tensorkeys_for_function = {}
 
@@ -198,7 +199,7 @@ def set_framework_adapter(self, framework_adapter):
         of the model with the purpose to make a list of parameters to be aggregated.
""" self.framework_adapter = framework_adapter - if self.opt_treatment == 'CONTINUE_GLOBAL': + if self.opt_treatment == OptTreatment.CONTINUE_GLOBAL: aggregate_optimizer_parameters = True else: aggregate_optimizer_parameters = False @@ -221,11 +222,11 @@ def rebuild_model(self, input_tensor_dict, validation=False, device='cpu'): Returns: None """ - if self.opt_treatment == 'RESET': + if self.opt_treatment == OptTreatment.RESET: self.reset_opt_vars() self.set_tensor_dict(input_tensor_dict, with_opt_vars=False, device=device) elif (self.training_round_completed - and self.opt_treatment == 'CONTINUE_GLOBAL' and not validation): + and self.opt_treatment == OptTreatment.CONTINUE_GLOBAL and not validation): self.set_tensor_dict(input_tensor_dict, with_opt_vars=True, device=device) else: self.set_tensor_dict(input_tensor_dict, with_opt_vars=False, device=device) diff --git a/openfl/interface/interactive_api/experiment.py b/openfl/interface/interactive_api/experiment.py index a1e3194d5d..49c5f7fd40 100644 --- a/openfl/interface/interactive_api/experiment.py +++ b/openfl/interface/interactive_api/experiment.py @@ -18,6 +18,8 @@ from openfl.component.assigner.tasks import Task from openfl.component.assigner.tasks import TrainTask from openfl.component.assigner.tasks import ValidateTask +from openfl.utilities.enum_types import DevicePolicy +from openfl.utilities.enum_types import OptTreatment from openfl.federated import Plan from openfl.interface.cli import setup_logging from openfl.interface.cli_helper import WORKSPACE @@ -104,7 +106,8 @@ def _rebuild_model(self, tensor_dict, upcoming_model_status=ModelStatus.BEST): self.logger.warning(warning_msg) else: - self.task_runner_stub.rebuild_model(tensor_dict, validation=True, device='cpu') + self.task_runner_stub.rebuild_model( + tensor_dict, validation=True, device='cpu') self.current_model_status = upcoming_model_status return deepcopy(self.task_runner_stub.model) @@ -166,8 +169,8 @@ def start(self, *, model_provider, task_keeper, data_loader, rounds_to_train: int, task_assigner=None, delta_updates: bool = False, - opt_treatment: str = 'RESET', - device_assignment_policy: str = 'CPU_ONLY', + opt_treatment: OptTreatment = OptTreatment.RESET, + device_assignment_policy: DevicePolicy = DevicePolicy.CPU_ONLY, pip_install_options: Tuple[str] = ()) -> None: """ Prepare workspace distribution and send to Director. 
@@ -194,6 +197,7 @@ def start(self, *, model_provider, task_keeper, data_loader,
             pip_install_options - tuple of options for the remote `pip install` calls,
                 example: ('-f some.website', '--no-index')
         """
+
         if not task_assigner:
             task_assigner = self.define_task_assigner(task_keeper, rounds_to_train)
 
@@ -361,9 +365,9 @@ def _prepare_plan(self, model_provider, data_loader,
 
         # Collaborator part
         plan.config['collaborator']['settings']['delta_updates'] = delta_updates
-        plan.config['collaborator']['settings']['opt_treatment'] = opt_treatment
+        plan.config['collaborator']['settings']['opt_treatment'] = opt_treatment.name
         plan.config['collaborator']['settings'][
-            'device_assignment_policy'] = device_assignment_policy
+            'device_assignment_policy'] = device_assignment_policy.name
 
         # DataLoader part
         for setting, value in data_loader.kwargs.items():
diff --git a/openfl/utilities/enum_types.py b/openfl/utilities/enum_types.py
new file mode 100644
index 0000000000..24c06785d7
--- /dev/null
+++ b/openfl/utilities/enum_types.py
@@ -0,0 +1,31 @@
+# Copyright (C) 2020-2021 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+"""Enum types for device policy and opt treatment."""
+
+from enum import Enum
+
+
+class DevicePolicy(Enum):
+    """Device assignment policy."""
+
+    CPU_ONLY = 1
+    CUDA_PREFERRED = 2
+
+
+class OptTreatment(Enum):
+    """Optimizer state treatment options.
+
+    - RESET tells each collaborator to reset the optimizer state at the beginning
+    of each round.
+
+    - CONTINUE_LOCAL tells each collaborator to continue with the local optimizer
+    state from the previous round.
+
+    - CONTINUE_GLOBAL tells each collaborator to continue with the federally
+    averaged optimizer state from the previous round.
+    """
+
+    RESET = 1
+    CONTINUE_LOCAL = 2
+    CONTINUE_GLOBAL = 3
diff --git a/tests/github/interactive_api_director/experiments/pytorch_kvasir_unet/experiment.py b/tests/github/interactive_api_director/experiments/pytorch_kvasir_unet/experiment.py
index e4072fe220..4225ddc49b 100644
--- a/tests/github/interactive_api_director/experiments/pytorch_kvasir_unet/experiment.py
+++ b/tests/github/interactive_api_director/experiments/pytorch_kvasir_unet/experiment.py
@@ -11,6 +11,7 @@
 from torch.utils.data import SubsetRandomSampler
 from torchvision import transforms as tsf
 
+from openfl.utilities.enum_types import OptTreatment
 from openfl.interface.interactive_api.experiment import DataInterface
 from openfl.interface.interactive_api.experiment import FLExperiment
 from openfl.interface.interactive_api.experiment import ModelInterface
@@ -145,7 +146,7 @@ def get_valid_data_size(self):
                          task_keeper=task_interface,
                          data_loader=fed_dataset,
                          rounds_to_train=2,
-                         opt_treatment='CONTINUE_GLOBAL')
+                         opt_treatment=OptTreatment.CONTINUE_GLOBAL)
 fl_experiment.stream_metrics()
 best_model = fl_experiment.get_best_model()
 fl_experiment.remove_experiment_data()
diff --git a/tests/github/interactive_api_director/experiments/tensorflow_mnist/experiment.py b/tests/github/interactive_api_director/experiments/tensorflow_mnist/experiment.py
index 7d06ab4921..63597e3ea4 100644
--- a/tests/github/interactive_api_director/experiments/tensorflow_mnist/experiment.py
+++ b/tests/github/interactive_api_director/experiments/tensorflow_mnist/experiment.py
@@ -1,6 +1,7 @@
 import time
 import tensorflow as tf
 # Create a federation
+from openfl.utilities.enum_types import OptTreatment
 from openfl.interface.interactive_api.federation import Federation
 from openfl.interface.interactive_api.experiment import TaskInterface, DataInterface, ModelInterface, FLExperiment
 from tests.github.interactive_api_director.experiments.tensorflow_mnist.dataset import FedDataset
@@ -118,7 +119,7 @@ def validate(model, val_dataset, device):
                      task_keeper=TI,
                      data_loader=fed_dataset,
                      rounds_to_train=2,
-                     opt_treatment='CONTINUE_GLOBAL')
+                     opt_treatment=OptTreatment.CONTINUE_GLOBAL)
 
 fl_experiment.stream_metrics()
 best_model = fl_experiment.get_best_model()
diff --git a/tests/openfl/component/collaborator/test_collaborator.py b/tests/openfl/component/collaborator/test_collaborator.py
index 58cb342dd5..119995e2b0 100644
--- a/tests/openfl/component/collaborator/test_collaborator.py
+++ b/tests/openfl/component/collaborator/test_collaborator.py
@@ -8,6 +8,7 @@
 import pytest
 
 from openfl.component.collaborator import Collaborator
+from openfl.utilities.enum_types import OptTreatment
 from openfl.protocols import base_pb2
 from openfl.utilities.types import TensorKey
 
@@ -16,7 +17,7 @@
 def collaborator_mock():
     """Initialize the collaborator mock."""
     col = Collaborator('col1', 'some_uuid', 'federation_uuid',
-                       mock.Mock(), mock.Mock(), mock.Mock(), opt_treatment='RESET')
+                       mock.Mock(), mock.Mock(), mock.Mock(), opt_treatment=OptTreatment.RESET)
 
     col.tensor_db = mock.Mock()
     return col
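
Usage note: with this change, FLExperiment.start and the Collaborator take OptTreatment and DevicePolicy enum members instead of bare strings. A minimal sketch of the migrated call, assuming `fl_experiment`, `MI`, `TI`, and `fed_dataset` have been built as in the tutorials above:

    from openfl.utilities.enum_types import DevicePolicy, OptTreatment

    fl_experiment.start(model_provider=MI,
                        task_keeper=TI,
                        data_loader=fed_dataset,
                        rounds_to_train=5,
                        opt_treatment=OptTreatment.CONTINUE_GLOBAL,            # or RESET / CONTINUE_LOCAL
                        device_assignment_policy=DevicePolicy.CUDA_PREFERRED)  # or CPU_ONLY

    # Round trip used internally: FLExperiment._prepare_plan writes the member's
    # .name into the plan, and Plan.get_collaborator looks the member up by name.
    assert OptTreatment.CONTINUE_GLOBAL.name == 'CONTINUE_GLOBAL'
    assert OptTreatment['CONTINUE_GLOBAL'] is OptTreatment.CONTINUE_GLOBAL

Bare strings no longer work with FLExperiment.start, which now reads opt_treatment.name, so callers should migrate to the enum members.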