diff --git a/openfl/component/aggregator/aggregator.py b/openfl/component/aggregator/aggregator.py index bd57ca48af..bfea92e2ed 100644 --- a/openfl/component/aggregator/aggregator.py +++ b/openfl/component/aggregator/aggregator.py @@ -614,9 +614,10 @@ def send_local_task_results(self, collaborator_name, round_number, task_name, collab_name for collab_name in self.authorized_cols if collab_name not in self.collaborators_done ] - if len(self.stragglers) != 0: self.logger.warning( - f"Identified straggler collaborators: {self.stragglers}" - ) + if len(self.stragglers) != 0: + self.logger.warning( + f"Identified straggler collaborators: {self.stragglers}" + ) self._end_of_round_check() def _process_named_tensor(self, named_tensor, collaborator_name): diff --git a/openfl/component/straggler_handling_functions/cutoff_time_based_straggler_handling.py b/openfl/component/straggler_handling_functions/cutoff_time_based_straggler_handling.py index e503ce574b..2305825750 100644 --- a/openfl/component/straggler_handling_functions/cutoff_time_based_straggler_handling.py +++ b/openfl/component/straggler_handling_functions/cutoff_time_based_straggler_handling.py @@ -121,8 +121,8 @@ def __straggler_time_expired(self) -> bool: Determines if straggler_cutoff_time is elapsed. """ return ( - self.round_start_time is not None and - ((time.time() - self.round_start_time) > self.straggler_cutoff_time) + self.round_start_time is not None + and ((time.time() - self.round_start_time) > self.straggler_cutoff_time) ) def __minimum_collaborators_reported(self, num_collaborators_done) -> bool: diff --git a/openfl/component/straggler_handling_functions/percentage_based_straggler_handling.py b/openfl/component/straggler_handling_functions/percentage_based_straggler_handling.py index 6096a1eb48..1a8804c290 100644 --- a/openfl/component/straggler_handling_functions/percentage_based_straggler_handling.py +++ b/openfl/component/straggler_handling_functions/percentage_based_straggler_handling.py @@ -3,7 +3,6 @@ """Percentage based Straggler Handling function.""" from logging import getLogger -from typing import Callable from openfl.component.straggler_handling_functions import StragglerHandlingPolicy diff --git a/openfl/component/straggler_handling_functions/straggler_handling_function.py b/openfl/component/straggler_handling_functions/straggler_handling_function.py index 31443862bd..52104b6376 100644 --- a/openfl/component/straggler_handling_functions/straggler_handling_function.py +++ b/openfl/component/straggler_handling_functions/straggler_handling_function.py @@ -51,7 +51,7 @@ def straggler_cutoff_check( Number of collaborators finished. num_all_collaborators: int Total number of collaborators. - + Returns: bool """ diff --git a/openfl/native/fastestimator.py b/openfl/native/fastestimator.py index e2d659563a..5700f61613 100644 --- a/openfl/native/fastestimator.py +++ b/openfl/native/fastestimator.py @@ -75,9 +75,7 @@ def fit(self): aggregator = plan.get_aggregator() - model_states = { - collaborator: None for collaborator in plan.authorized_cols - } + model_states = dict.fromkeys(plan.authorized_cols, None) runners = {} save_dir = {} data_path = 1 diff --git a/openfl/utilities/optimizers/numpy/adam_optimizer.py b/openfl/utilities/optimizers/numpy/adam_optimizer.py index 8660a59855..79adb76bbb 100644 --- a/openfl/utilities/optimizers/numpy/adam_optimizer.py +++ b/openfl/utilities/optimizers/numpy/adam_optimizer.py @@ -72,7 +72,7 @@ def __init__( self.beta_1, self.beta_2 = betas self.initial_accumulator_value = initial_accumulator_value self.epsilon = epsilon - self.current_step: Dict[str, int] = {param_name: 0 for param_name in self.params} + self.current_step: Dict[str, int] = dict.fromkeys(self.params, 0) self.grads_first_moment, self.grads_second_moment = {}, {}