From b5c5f2fceaac2caad8e101626dc08e7186ac1132 Mon Sep 17 00:00:00 2001 From: Parth Mandaliya <parthx.mandaliya@intel.com> Date: Fri, 29 Sep 2023 17:23:29 +0530 Subject: [PATCH] Added weighted_average aggregation function under openfl.experimental.interface.{keras,torch}.aggregation_funtions Signed-off-by: Parth Mandaliya <parthx.mandaliya@intel.com> Signed-off-by: Parth Mandaliya <parth.mandaliya.007@gmail.com> --- .../experimental/interface/keras/__init__.py | 7 ++ .../keras/aggregation_functions/__init__.py | 7 ++ .../aggregation_functions/weighted_average.py | 13 ++++ .../experimental/interface/torch/__init__.py | 7 ++ .../torch/aggregation_functions/__init__.py | 7 ++ .../aggregation_functions/weighted_average.py | 77 +++++++++++++++++++ setup.py | 4 + 7 files changed, 122 insertions(+) create mode 100644 openfl/experimental/interface/keras/__init__.py create mode 100644 openfl/experimental/interface/keras/aggregation_functions/__init__.py create mode 100644 openfl/experimental/interface/keras/aggregation_functions/weighted_average.py create mode 100644 openfl/experimental/interface/torch/__init__.py create mode 100644 openfl/experimental/interface/torch/aggregation_functions/__init__.py create mode 100644 openfl/experimental/interface/torch/aggregation_functions/weighted_average.py diff --git a/openfl/experimental/interface/keras/__init__.py b/openfl/experimental/interface/keras/__init__.py new file mode 100644 index 0000000000..1d7d84eb7f --- /dev/null +++ b/openfl/experimental/interface/keras/__init__.py @@ -0,0 +1,7 @@ +# Copyright (C) 2020-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +"""openfl.experimental.interface.keras package.""" + +from .aggregation_functions import WeightedAverage + +__all__ = ["WeightedAverage", ] diff --git a/openfl/experimental/interface/keras/aggregation_functions/__init__.py b/openfl/experimental/interface/keras/aggregation_functions/__init__.py new file mode 100644 index 0000000000..94708487bc --- /dev/null +++ b/openfl/experimental/interface/keras/aggregation_functions/__init__.py @@ -0,0 +1,7 @@ +# Copyright (C) 2020-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +"""openfl.experimenal.interface.keras.aggregation_functions package.""" + +from .weighted_average import WeightedAverage + +__all__ = ["WeightedAverage", ] diff --git a/openfl/experimental/interface/keras/aggregation_functions/weighted_average.py b/openfl/experimental/interface/keras/aggregation_functions/weighted_average.py new file mode 100644 index 0000000000..326e57aece --- /dev/null +++ b/openfl/experimental/interface/keras/aggregation_functions/weighted_average.py @@ -0,0 +1,13 @@ +# Copyright (C) 2020-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +"""openfl.experimental.interface.keras.aggregation_functions.weighted_average package.""" + + +class WeightedAverage: + """Weighted average aggregation for keras or tensorflow.""" + + def __init__(self) -> None: + """ + WeightedAverage class for Keras or Tensorflow library. + """ + raise NotImplementedError("WeightedAverage for keras will be implemented in the future.") diff --git a/openfl/experimental/interface/torch/__init__.py b/openfl/experimental/interface/torch/__init__.py new file mode 100644 index 0000000000..969f47b43a --- /dev/null +++ b/openfl/experimental/interface/torch/__init__.py @@ -0,0 +1,7 @@ +# Copyright (C) 2020-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +"""openfl.experimental.interface.torch package.""" + +from .aggregation_functions import WeightedAverage + +__all__ = ["WeightedAverage", ] diff --git a/openfl/experimental/interface/torch/aggregation_functions/__init__.py b/openfl/experimental/interface/torch/aggregation_functions/__init__.py new file mode 100644 index 0000000000..2afa83b219 --- /dev/null +++ b/openfl/experimental/interface/torch/aggregation_functions/__init__.py @@ -0,0 +1,7 @@ +# Copyright (C) 2020-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +"""openfl.experimenal.interface.torch.aggregation_functions package.""" + +from .weighted_average import WeightedAverage + +__all__ = ["WeightedAverage", ] diff --git a/openfl/experimental/interface/torch/aggregation_functions/weighted_average.py b/openfl/experimental/interface/torch/aggregation_functions/weighted_average.py new file mode 100644 index 0000000000..a91cadfa0d --- /dev/null +++ b/openfl/experimental/interface/torch/aggregation_functions/weighted_average.py @@ -0,0 +1,77 @@ +# Copyright (C) 2020-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +"""openfl.experimental.interface.torch.aggregation_functions.weighted_average package.""" + +import collections +import numpy as np +import torch as pt + + +def weighted_average(tensors, weights): + """Compute weighted average.""" + return np.average(tensors, weights=weights, axis=0) + + +class WeightedAverage: + """Weighted average aggregation.""" + + def __call__(self, objects_list, weights_list) -> np.ndarray: + """ + Compute weighted average of models, optimizers, loss, or accuracy metrics. + For taking weighted average of optimizer do the following steps: + 1. Call "_get_optimizer_state" (openfl.federated.task.runner_pt._get_optimizer_state) + pass optimizer to it, to take optimizer state dictionary. + 2. Pass optimizer state dictionaries list to here. + 3. To set the weighted average optimizer state dictionary back to optimizer, + call "_set_optimizer_state" (openfl.federated.task.runner_pt._set_optimizer_state) + and pass optimizer, device, and optimizer dictionary received in step 2. + + Args: + objects_list: List of objects for which weighted average is to be computed. + - List of Model state dictionaries , or + - List of Metrics (Loss or accuracy), or + - List of optimizer state dictionaries (following steps need to be performed) + 1. Obtain optimizer state dictionary by invoking "_get_optimizer_state" + (openfl.federated.task.runner_pt._get_optimizer_state). + 2. Create a list of optimizer state dictionary obtained in step - 1 + Invoke WeightedAverage on this list. + 3. Invoke "_set_optimizer_state" to set weighted average of optimizer + state back to optimizer (openfl.federated.task.runner_pt._set_optimizer_state). + weights_list: Weight for each element in the list. + + Returns: + dict: For model or optimizer + float: For Loss or Accuracy metrics + """ + # Check the type of first element of tensors list + if type(objects_list[0]) in (dict, collections.OrderedDict): + optimizer = False + # If __opt_state_needed found then optimizer state dictionary is passed + if "__opt_state_needed" in objects_list[0]: + optimizer = True + # Remove __opt_state_needed from all state dictionary in list, and + # check if weightedaverage of optimizer can be taken. + for tensor in objects_list: + error_msg = "Optimizer is stateless, WeightedAverage cannot be taken" + assert tensor.pop("__opt_state_needed") == "true", error_msg + + tmp_list = [] + # # Take keys in order to rebuild the state dictionary taking keys back up + for tensor in objects_list: + # Append values of each state dictionary in list + # If type(value) is Tensor then it needs to be detached + tmp_list.append(np.array([value.detach() if isinstance(value, pt.Tensor) else value + for value in tensor.values()], dtype=object)) + # Take weighted average of list of arrays + # new_params passed is weighted average of each array in tmp_list + new_params = weighted_average(tmp_list, weights_list) + new_state = {} + # Take weighted average parameters and building a dictionary + for i, k in enumerate(objects_list[0].keys()): + if optimizer: + new_state[k] = new_params[i] + else: + new_state[k] = pt.from_numpy(new_params[i].numpy()) + return new_state + else: + return weighted_average(objects_list, weights_list) diff --git a/setup.py b/setup.py index 845d960813..1b3b14ac74 100644 --- a/setup.py +++ b/setup.py @@ -103,6 +103,10 @@ def run(self): 'openfl.databases.utilities', 'openfl.experimental', 'openfl.experimental.interface', + 'openfl.experimental.interface.keras', + 'openfl.experimental.interface.keras.aggregation_functions', + 'openfl.experimental.interface.torch', + 'openfl.experimental.interface.torch.aggregation_functions', 'openfl.experimental.placement', 'openfl.experimental.runtime', 'openfl.experimental.utilities',