Skip to content

Commit

Permalink
Added weighted_average aggregation function under openfl.experimental…
Browse files Browse the repository at this point in the history
….interface.{keras,torch}.aggregation_funtions

Signed-off-by: Parth Mandaliya <[email protected]>
Signed-off-by: Parth Mandaliya <[email protected]>
  • Loading branch information
ParthM-GitHub authored and ParthMandaliya committed Oct 5, 2023
1 parent f542cdf commit b5c5f2f
Show file tree
Hide file tree
Showing 7 changed files with 122 additions and 0 deletions.
7 changes: 7 additions & 0 deletions openfl/experimental/interface/keras/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Copyright (C) 2020-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""openfl.experimental.interface.keras package."""

from .aggregation_functions import WeightedAverage

__all__ = ["WeightedAverage", ]
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Copyright (C) 2020-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""openfl.experimenal.interface.keras.aggregation_functions package."""

from .weighted_average import WeightedAverage

__all__ = ["WeightedAverage", ]
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (C) 2020-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""openfl.experimental.interface.keras.aggregation_functions.weighted_average package."""


class WeightedAverage:
"""Weighted average aggregation for keras or tensorflow."""

def __init__(self) -> None:
"""
WeightedAverage class for Keras or Tensorflow library.
"""
raise NotImplementedError("WeightedAverage for keras will be implemented in the future.")
7 changes: 7 additions & 0 deletions openfl/experimental/interface/torch/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Copyright (C) 2020-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""openfl.experimental.interface.torch package."""

from .aggregation_functions import WeightedAverage

__all__ = ["WeightedAverage", ]
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Copyright (C) 2020-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""openfl.experimenal.interface.torch.aggregation_functions package."""

from .weighted_average import WeightedAverage

__all__ = ["WeightedAverage", ]
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Copyright (C) 2020-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""openfl.experimental.interface.torch.aggregation_functions.weighted_average package."""

import collections
import numpy as np
import torch as pt


def weighted_average(tensors, weights):
"""Compute weighted average."""
return np.average(tensors, weights=weights, axis=0)


class WeightedAverage:
"""Weighted average aggregation."""

def __call__(self, objects_list, weights_list) -> np.ndarray:
"""
Compute weighted average of models, optimizers, loss, or accuracy metrics.
For taking weighted average of optimizer do the following steps:
1. Call "_get_optimizer_state" (openfl.federated.task.runner_pt._get_optimizer_state)
pass optimizer to it, to take optimizer state dictionary.
2. Pass optimizer state dictionaries list to here.
3. To set the weighted average optimizer state dictionary back to optimizer,
call "_set_optimizer_state" (openfl.federated.task.runner_pt._set_optimizer_state)
and pass optimizer, device, and optimizer dictionary received in step 2.
Args:
objects_list: List of objects for which weighted average is to be computed.
- List of Model state dictionaries , or
- List of Metrics (Loss or accuracy), or
- List of optimizer state dictionaries (following steps need to be performed)
1. Obtain optimizer state dictionary by invoking "_get_optimizer_state"
(openfl.federated.task.runner_pt._get_optimizer_state).
2. Create a list of optimizer state dictionary obtained in step - 1
Invoke WeightedAverage on this list.
3. Invoke "_set_optimizer_state" to set weighted average of optimizer
state back to optimizer (openfl.federated.task.runner_pt._set_optimizer_state).
weights_list: Weight for each element in the list.
Returns:
dict: For model or optimizer
float: For Loss or Accuracy metrics
"""
# Check the type of first element of tensors list
if type(objects_list[0]) in (dict, collections.OrderedDict):
optimizer = False
# If __opt_state_needed found then optimizer state dictionary is passed
if "__opt_state_needed" in objects_list[0]:
optimizer = True
# Remove __opt_state_needed from all state dictionary in list, and
# check if weightedaverage of optimizer can be taken.
for tensor in objects_list:
error_msg = "Optimizer is stateless, WeightedAverage cannot be taken"
assert tensor.pop("__opt_state_needed") == "true", error_msg

tmp_list = []
# # Take keys in order to rebuild the state dictionary taking keys back up
for tensor in objects_list:
# Append values of each state dictionary in list
# If type(value) is Tensor then it needs to be detached
tmp_list.append(np.array([value.detach() if isinstance(value, pt.Tensor) else value
for value in tensor.values()], dtype=object))
# Take weighted average of list of arrays
# new_params passed is weighted average of each array in tmp_list
new_params = weighted_average(tmp_list, weights_list)
new_state = {}
# Take weighted average parameters and building a dictionary
for i, k in enumerate(objects_list[0].keys()):
if optimizer:
new_state[k] = new_params[i]
else:
new_state[k] = pt.from_numpy(new_params[i].numpy())
return new_state
else:
return weighted_average(objects_list, weights_list)
4 changes: 4 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@ def run(self):
'openfl.databases.utilities',
'openfl.experimental',
'openfl.experimental.interface',
'openfl.experimental.interface.keras',
'openfl.experimental.interface.keras.aggregation_functions',
'openfl.experimental.interface.torch',
'openfl.experimental.interface.torch.aggregation_functions',
'openfl.experimental.placement',
'openfl.experimental.runtime',
'openfl.experimental.utilities',
Expand Down

0 comments on commit b5c5f2f

Please sign in to comment.