From b5c5f2fceaac2caad8e101626dc08e7186ac1132 Mon Sep 17 00:00:00 2001
From: Parth Mandaliya <parthx.mandaliya@intel.com>
Date: Fri, 29 Sep 2023 17:23:29 +0530
Subject: [PATCH] Added weighted_average aggregation function under
 openfl.experimental.interface.{keras,torch}.aggregation_funtions

Signed-off-by: Parth Mandaliya <parthx.mandaliya@intel.com>
Signed-off-by: Parth Mandaliya <parth.mandaliya.007@gmail.com>
---
 .../experimental/interface/keras/__init__.py  |  7 ++
 .../keras/aggregation_functions/__init__.py   |  7 ++
 .../aggregation_functions/weighted_average.py | 13 ++++
 .../experimental/interface/torch/__init__.py  |  7 ++
 .../torch/aggregation_functions/__init__.py   |  7 ++
 .../aggregation_functions/weighted_average.py | 77 +++++++++++++++++++
 setup.py                                      |  4 +
 7 files changed, 122 insertions(+)
 create mode 100644 openfl/experimental/interface/keras/__init__.py
 create mode 100644 openfl/experimental/interface/keras/aggregation_functions/__init__.py
 create mode 100644 openfl/experimental/interface/keras/aggregation_functions/weighted_average.py
 create mode 100644 openfl/experimental/interface/torch/__init__.py
 create mode 100644 openfl/experimental/interface/torch/aggregation_functions/__init__.py
 create mode 100644 openfl/experimental/interface/torch/aggregation_functions/weighted_average.py

diff --git a/openfl/experimental/interface/keras/__init__.py b/openfl/experimental/interface/keras/__init__.py
new file mode 100644
index 0000000000..1d7d84eb7f
--- /dev/null
+++ b/openfl/experimental/interface/keras/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (C) 2020-2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+"""openfl.experimental.interface.keras package."""
+
+from .aggregation_functions import WeightedAverage
+
+__all__ = ["WeightedAverage", ]
diff --git a/openfl/experimental/interface/keras/aggregation_functions/__init__.py b/openfl/experimental/interface/keras/aggregation_functions/__init__.py
new file mode 100644
index 0000000000..94708487bc
--- /dev/null
+++ b/openfl/experimental/interface/keras/aggregation_functions/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (C) 2020-2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+"""openfl.experimenal.interface.keras.aggregation_functions package."""
+
+from .weighted_average import WeightedAverage
+
+__all__ = ["WeightedAverage", ]
diff --git a/openfl/experimental/interface/keras/aggregation_functions/weighted_average.py b/openfl/experimental/interface/keras/aggregation_functions/weighted_average.py
new file mode 100644
index 0000000000..326e57aece
--- /dev/null
+++ b/openfl/experimental/interface/keras/aggregation_functions/weighted_average.py
@@ -0,0 +1,13 @@
+# Copyright (C) 2020-2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+"""openfl.experimental.interface.keras.aggregation_functions.weighted_average package."""
+
+
+class WeightedAverage:
+    """Weighted average aggregation for keras or tensorflow."""
+
+    def __init__(self) -> None:
+        """
+        WeightedAverage class for Keras or Tensorflow library.
+        """
+        raise NotImplementedError("WeightedAverage for keras will be implemented in the future.")
diff --git a/openfl/experimental/interface/torch/__init__.py b/openfl/experimental/interface/torch/__init__.py
new file mode 100644
index 0000000000..969f47b43a
--- /dev/null
+++ b/openfl/experimental/interface/torch/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (C) 2020-2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+"""openfl.experimental.interface.torch package."""
+
+from .aggregation_functions import WeightedAverage
+
+__all__ = ["WeightedAverage", ]
diff --git a/openfl/experimental/interface/torch/aggregation_functions/__init__.py b/openfl/experimental/interface/torch/aggregation_functions/__init__.py
new file mode 100644
index 0000000000..2afa83b219
--- /dev/null
+++ b/openfl/experimental/interface/torch/aggregation_functions/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (C) 2020-2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+"""openfl.experimenal.interface.torch.aggregation_functions package."""
+
+from .weighted_average import WeightedAverage
+
+__all__ = ["WeightedAverage", ]
diff --git a/openfl/experimental/interface/torch/aggregation_functions/weighted_average.py b/openfl/experimental/interface/torch/aggregation_functions/weighted_average.py
new file mode 100644
index 0000000000..a91cadfa0d
--- /dev/null
+++ b/openfl/experimental/interface/torch/aggregation_functions/weighted_average.py
@@ -0,0 +1,77 @@
+# Copyright (C) 2020-2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+"""openfl.experimental.interface.torch.aggregation_functions.weighted_average package."""
+
+import collections
+import numpy as np
+import torch as pt
+
+
+def weighted_average(tensors, weights):
+    """Compute weighted average."""
+    return np.average(tensors, weights=weights, axis=0)
+
+
+class WeightedAverage:
+    """Weighted average aggregation."""
+
+    def __call__(self, objects_list, weights_list) -> np.ndarray:
+        """
+        Compute weighted average of models, optimizers, loss, or accuracy metrics.
+        For taking weighted average of optimizer do the following steps:
+        1.  Call "_get_optimizer_state" (openfl.federated.task.runner_pt._get_optimizer_state)
+            pass optimizer to it, to take optimizer state dictionary.
+        2.  Pass optimizer state dictionaries list to here.
+        3.  To set the weighted average optimizer state dictionary back to optimizer,
+            call "_set_optimizer_state" (openfl.federated.task.runner_pt._set_optimizer_state)
+            and pass optimizer, device, and optimizer dictionary received in step 2.
+
+        Args:
+            objects_list: List of objects for which weighted average is to be computed.
+            - List of Model state dictionaries , or
+            - List of Metrics (Loss or accuracy), or
+            - List of optimizer state dictionaries (following steps need to be performed)
+                1. Obtain optimizer state dictionary by invoking "_get_optimizer_state"
+                   (openfl.federated.task.runner_pt._get_optimizer_state).
+                2. Create a list of optimizer state dictionary obtained in step - 1
+                   Invoke WeightedAverage on this list.
+                3. Invoke "_set_optimizer_state" to set weighted average of optimizer
+                   state back to optimizer (openfl.federated.task.runner_pt._set_optimizer_state).
+            weights_list: Weight for each element in the list.
+
+        Returns:
+            dict: For model or optimizer
+            float: For Loss or Accuracy metrics
+        """
+        # Check the type of first element of tensors list
+        if type(objects_list[0]) in (dict, collections.OrderedDict):
+            optimizer = False
+            # If __opt_state_needed found then optimizer state dictionary is passed
+            if "__opt_state_needed" in objects_list[0]:
+                optimizer = True
+                # Remove __opt_state_needed from all state dictionary in list, and
+                # check if weightedaverage of optimizer can be taken.
+                for tensor in objects_list:
+                    error_msg = "Optimizer is stateless, WeightedAverage cannot be taken"
+                    assert tensor.pop("__opt_state_needed") == "true", error_msg
+
+            tmp_list = []
+            # # Take keys in order to rebuild the state dictionary taking keys back up
+            for tensor in objects_list:
+                # Append values of each state dictionary in list
+                # If type(value) is Tensor then it needs to be detached
+                tmp_list.append(np.array([value.detach() if isinstance(value, pt.Tensor) else value
+                                          for value in tensor.values()], dtype=object))
+            # Take weighted average of list of arrays
+            # new_params passed is weighted average of each array in tmp_list
+            new_params = weighted_average(tmp_list, weights_list)
+            new_state = {}
+            # Take weighted average parameters and building a dictionary
+            for i, k in enumerate(objects_list[0].keys()):
+                if optimizer:
+                    new_state[k] = new_params[i]
+                else:
+                    new_state[k] = pt.from_numpy(new_params[i].numpy())
+            return new_state
+        else:
+            return weighted_average(objects_list, weights_list)
diff --git a/setup.py b/setup.py
index 845d960813..1b3b14ac74 100644
--- a/setup.py
+++ b/setup.py
@@ -103,6 +103,10 @@ def run(self):
         'openfl.databases.utilities',
         'openfl.experimental',
         'openfl.experimental.interface',
+        'openfl.experimental.interface.keras',
+        'openfl.experimental.interface.keras.aggregation_functions',
+        'openfl.experimental.interface.torch',
+        'openfl.experimental.interface.torch.aggregation_functions',
         'openfl.experimental.placement',
         'openfl.experimental.runtime',
         'openfl.experimental.utilities',