From f552f3980a27c60a8f7ef33ddd9e852566a68924 Mon Sep 17 00:00:00 2001 From: manuelhsantana Date: Wed, 3 Jul 2024 18:35:19 -0700 Subject: [PATCH] Update aggregator docstring for fix lint warnings --- openfl/component/aggregator/aggregator.py | 415 +++++++++++----------- 1 file changed, 200 insertions(+), 215 deletions(-) diff --git a/openfl/component/aggregator/aggregator.py b/openfl/component/aggregator/aggregator.py index da52576c89..ed90fcf96a 100644 --- a/openfl/component/aggregator/aggregator.py +++ b/openfl/component/aggregator/aggregator.py @@ -2,20 +2,17 @@ # SPDX-License-Identifier: Apache-2.0 """Aggregator module.""" -import time import queue +import time from logging import getLogger -from openfl.interface.aggregation_functions import WeightedAverage -from openfl.component.straggler_handling_functions import CutoffTimeBasedStragglerHandling +from openfl.component.straggler_handling_functions import \ + CutoffTimeBasedStragglerHandling from openfl.databases import TensorDB -from openfl.pipelines import NoCompressionPipeline -from openfl.pipelines import TensorCodec -from openfl.protocols import base_pb2 -from openfl.protocols import utils -from openfl.utilities import change_tags -from openfl.utilities import TaskResultKey -from openfl.utilities import TensorKey +from openfl.interface.aggregation_functions import WeightedAverage +from openfl.pipelines import NoCompressionPipeline, TensorCodec +from openfl.protocols import base_pb2, utils +from openfl.utilities import TaskResultKey, TensorKey, change_tags from openfl.utilities.logs import write_metric @@ -24,48 +21,50 @@ class Aggregator: Attributes: round_number (int): Current round number. - single_col_cert_common_name (str): Common name for single collaborator certificate. - straggler_handling_policy: Straggler handling policy. - _end_of_round_check_done (list of bool): List indicating whether end of round check is done for each round. + single_col_cert_common_name (str): Common name for single + collaborator certificate. + straggler_handling_policy: Policy for handling stragglers. + _end_of_round_check_done (list of bool): Indicates if end of round + check is done for each round. stragglers (list): List of stragglers. rounds_to_train (int): Number of rounds to train. - authorized_cols (list of str): The list of IDs of enrolled collaborators. + authorized_cols (list of str): IDs of enrolled collaborators. uuid (int): Aggregator UUID. federation_uuid (str): Federation UUID. - assigner: Assigner object. - quit_job_sent_to (list): List of collaborators to whom quit job is sent. - tensor_db (TensorDB): TensorDB object. + assigner: Object assigning tasks to collaborators. + quit_job_sent_to (list): Collaborators sent a quit job. + tensor_db (TensorDB): Object for tensor database. db_store_rounds* (int): Rounds to store in TensorDB. - logger: Logger object. - write_logs (bool): Whether to write logs. - log_metric_callback: Callback for log metric. - best_model_score (optional): Score of the best model. Defaults to None. + logger: Object for logging. + write_logs (bool): Flag to enable log writing. + log_metric_callback: Callback for logging metrics. + best_model_score (optional): Score of the best model. Defaults to + None. metric_queue (queue.Queue): Queue for metrics. - compression_pipeline: Compression pipeline. - tensor_codec (TensorCodec): Tensor codec. - init_state_path* (str): The location of the initial weight file. - best_state_path* (str): The file location to store the weight of the best model. - last_state_path* (str): The file location to store the latest weight. - best_tensor_dict (dict): Dictionary of the best tensors. - last_tensor_dict (dict): Dictionary of the last tensors. - collaborator_tensor_results (dict): Dictionary of collaborator tensor results. - collaborator_tasks_results (dict): Dictionary of collaborator tasks results. - collaborator_task_weight (dict): Dictionary of collaborator task weight. - - .. note:: + compression_pipeline: Pipeline for compressing data. + tensor_codec (TensorCodec): Codec for tensor compression. + init_state_path* (str): Initial weight file location. + best_state_path* (str): Where to store the best model weight. + last_state_path* (str): Where to store the latest model weight. + best_tensor_dict (dict): Dict of the best tensors. + last_tensor_dict (dict): Dict of the last tensors. + collaborator_tensor_results (dict): Dict of collaborator tensor + results. + collaborator_tasks_results (dict): Dict of collaborator tasks + results. + collaborator_task_weight (dict): Dict of collaborator task weight. + + .. note:: \* - plan setting. - """ + """ def __init__(self, - aggregator_uuid, federation_uuid, authorized_cols, - init_state_path, best_state_path, last_state_path, - assigner, straggler_handling_policy=None, rounds_to_train=256, @@ -80,18 +79,28 @@ def __init__(self, Args: aggregator_uuid (int): Aggregation ID. federation_uuid (str): Federation ID. - authorized_cols (list of str): The list of IDs of enrolled collaborators. + authorized_cols (list of str): The list of IDs of enrolled + collaborators. init_state_path (str): The location of the initial weight file. - best_state_path (str): The file location to store the weight of the best model. - last_state_path (str): The file location to store the latest weight. + best_state_path (str): The file location to store the weight of + the best model. + last_state_path (str): The file location to store the latest + weight. assigner: Assigner object. - straggler_handling_policy (optional): Straggler handling policy. Defaults to CutoffTimeBasedStragglerHandling. - rounds_to_train (int, optional): Number of rounds to train. Defaults to 256. - single_col_cert_common_name (str, optional): Common name for single collaborator certificate. Defaults to None. - compression_pipeline (optional): Compression pipeline. Defaults to NoCompressionPipeline. - db_store_rounds (int, optional): Rounds to store in TensorDB. Defaults to 1. - write_logs (bool, optional): Whether to write logs. Defaults to False. - log_metric_callback (optional): Callback for log metric. Defaults to None. + straggler_handling_policy (optional): Straggler handling policy. + Defaults to CutoffTimeBasedStragglerHandling. + rounds_to_train (int, optional): Number of rounds to train. + Defaults to 256. + single_col_cert_common_name (str, optional): Common name for single + collaborator certificate. Defaults to None. + compression_pipeline (optional): Compression pipeline. Defaults to + NoCompressionPipeline. + db_store_rounds (int, optional): Rounds to store in TensorDB. + Defaults to 1. + write_logs (bool, optional): Whether to write logs. Defaults to + False. + log_metric_callback (optional): Callback for log metric. Defaults + to None. **kwargs: Additional keyword arguments. """ self.round_number = 0 @@ -104,9 +113,8 @@ def __init__(self, # Cleaner solution? self.single_col_cert_common_name = '' - self.straggler_handling_policy = ( - straggler_handling_policy or CutoffTimeBasedStragglerHandling() - ) + self.straggler_handling_policy = (straggler_handling_policy or + CutoffTimeBasedStragglerHandling()) self._end_of_round_check_done = [False] * rounds_to_train self.stragglers = [] @@ -155,7 +163,8 @@ def __init__(self, round_number=0, tensor_pipe=self.compression_pipeline) else: - self.model: base_pb2.ModelProto = utils.load_proto(self.init_state_path) + self.model: base_pb2.ModelProto = utils.load_proto( + self.init_state_path) self._load_initial_tensors() # keys are TensorKeys self.collaborator_tensor_results = {} # {TensorKey: nparray}} @@ -321,7 +330,8 @@ def get_tasks(self, collaborator_name): time_to_quit = False # otherwise, get the tasks from our task assigner - tasks = self.assigner.get_tasks_for_collaborator(collaborator_name, self.round_number) + tasks = self.assigner.get_tasks_for_collaborator( + collaborator_name, self.round_number) # if no tasks, tell the collaborator to sleep if len(tasks) == 0: @@ -380,15 +390,16 @@ def get_aggregated_tensor(self, collaborator_name, tensor_name, report (bool): Whether to report. tags (tuple[str, ...]): Tags. require_lossless (bool): Whether to require lossless. - + Returns: named_tensor (protobuf) : NamedTensor, the tensor requested by the collaborator. - + Raises: ValueError: if Aggregator does not have an aggregated tensor for {tensor_key}. """ - self.logger.debug(f'Retrieving aggregated tensor {tensor_name},{round_number},{tags} ' - f'for collaborator {collaborator_name}') + self.logger.debug( + f'Retrieving aggregated tensor {tensor_name},{round_number},{tags} ' + f'for collaborator {collaborator_name}') if 'compressed' in tags or require_lossless: compress_lossless = True @@ -403,15 +414,13 @@ def get_aggregated_tensor(self, collaborator_name, tensor_name, if 'lossy_compressed' in tags: tags = change_tags(tags, remove_field='lossy_compressed') - tensor_key = TensorKey( - tensor_name, self.uuid, round_number, report, tags - ) + tensor_key = TensorKey(tensor_name, self.uuid, round_number, report, + tags) tensor_name, origin, round_number, report, tags = tensor_key if 'aggregated' in tags and 'delta' in tags and round_number != 0: - agg_tensor_key = TensorKey( - tensor_name, origin, round_number, report, ('aggregated',) - ) + agg_tensor_key = TensorKey(tensor_name, origin, round_number, + report, ('aggregated', )) else: agg_tensor_key = tensor_key @@ -426,7 +435,9 @@ def get_aggregated_tensor(self, collaborator_name, tensor_name, break if nparray is None: - raise ValueError(f'Aggregator does not have an aggregated tensor for {tensor_key}') + raise ValueError( + f'Aggregator does not have an aggregated tensor for {tensor_key}' + ) # quite a bit happens in here, including compression, delta handling, # etc... @@ -435,8 +446,7 @@ def get_aggregated_tensor(self, collaborator_name, tensor_name, agg_tensor_key, nparray, send_model_deltas=True, - compress_lossless=compress_lossless - ) + compress_lossless=compress_lossless) return named_tensor @@ -444,7 +454,8 @@ def _nparray_to_named_tensor(self, tensor_key, nparray, send_model_deltas, compress_lossless): """Construct the NamedTensor Protobuf. - Also includes logic to create delta, compress tensors with the TensorCodec, etc. + Also includes logic to create delta, compress tensors with the + TensorCodec, etc. Args: tensor_key (TensorKey): Tensor key. @@ -462,11 +473,8 @@ def _nparray_to_named_tensor(self, tensor_key, nparray, send_model_deltas, if 'aggregated' in tags and send_model_deltas: # Should get the pretrained model to create the delta. If training # has happened, Model should already be stored in the TensorDB - model_tk = TensorKey(tensor_name, - origin, - round_number - 1, - report, - ('model',)) + model_tk = TensorKey(tensor_name, origin, round_number - 1, report, + ('model', )) model_nparray = self.tensor_db.get_tensor_from_cache(model_tk) @@ -474,57 +482,49 @@ def _nparray_to_named_tensor(self, tensor_key, nparray, send_model_deltas, 'The original model layer should be present if the latest ' 'aggregated model is present') delta_tensor_key, delta_nparray = self.tensor_codec.generate_delta( - tensor_key, - nparray, - model_nparray - ) + tensor_key, nparray, model_nparray) delta_comp_tensor_key, delta_comp_nparray, metadata = self.tensor_codec.compress( - delta_tensor_key, - delta_nparray, - lossless=compress_lossless - ) + delta_tensor_key, delta_nparray, lossless=compress_lossless) named_tensor = utils.construct_named_tensor( delta_comp_tensor_key, delta_comp_nparray, metadata, - lossless=compress_lossless - ) + lossless=compress_lossless) else: # Assume every other tensor requires lossless compression compressed_tensor_key, compressed_nparray, metadata = self.tensor_codec.compress( - tensor_key, - nparray, - require_lossless=True - ) + tensor_key, nparray, require_lossless=True) named_tensor = utils.construct_named_tensor( compressed_tensor_key, compressed_nparray, metadata, - lossless=compress_lossless - ) + lossless=compress_lossless) return named_tensor def _collaborator_task_completed(self, collaborator, task_name, round_num): """Check if the collaborator has completed the task for the round. - The aggregator doesn't actually know which tensors should be sent from the collaborator - so it must to rely specifically on the presence of previous results. + The aggregator doesn't actually know which tensors should be sent from + the collaborator so it must to rely specifically on the presence of + previous results. Args: - collaborator (str): Collaborator to check if their task has been completed. + collaborator (str): Collaborator to check if their task has been + completed. task_name (str): The name of the task (TaskRunner function). round_num (int): Round number. Returns: - bool: Whether or not the collaborator has completed the task for this round. + bool: Whether or not the collaborator has completed the task for + this round. """ task_key = TaskResultKey(task_name, collaborator, round_num) return task_key in self.collaborator_tasks_results - def send_local_task_results(self, collaborator_name, round_number, task_name, - data_size, named_tensors): + def send_local_task_results(self, collaborator_name, round_number, + task_name, data_size, named_tensors): """RPC called by collaborator. Transmits collaborator's task results to the aggregator. @@ -542,32 +542,27 @@ def send_local_task_results(self, collaborator_name, round_number, task_name, if self._time_to_quit() or self._is_task_done(task_name): self.logger.warning( f'STRAGGLER: Collaborator {collaborator_name} is reporting results ' - 'after task {task_name} has finished.' - ) + 'after task {task_name} has finished.') return if self.round_number != round_number: self.logger.warning( f'Collaborator {collaborator_name} is reporting results' - f' for the wrong round: {round_number}. Ignoring...' - ) + f' for the wrong round: {round_number}. Ignoring...') return self.logger.info( f'Collaborator {collaborator_name} is sending task results ' - f'for {task_name}, round {round_number}' - ) + f'for {task_name}, round {round_number}') task_key = TaskResultKey(task_name, collaborator_name, round_number) # we mustn't have results already - if self._collaborator_task_completed( - collaborator_name, task_name, round_number - ): + if self._collaborator_task_completed(collaborator_name, task_name, + round_number): raise ValueError( f'Aggregator already has task results from collaborator {collaborator_name}' - f' for task {task_key}' - ) + f' for task {task_key}') # By giving task_key it's own weight, we can support different # training/validation weights @@ -606,7 +601,8 @@ def send_local_task_results(self, collaborator_name, round_number, task_name, def _process_named_tensor(self, named_tensor, collaborator_name): """Extract the named tensor fields. - Performs decompression, delta computation, and inserts results into TensorDB. + Performs decompression, delta computation, and inserts results into + TensorDB. Args: named_tensor (protobuf NamedTensor): Named tensor. @@ -622,77 +618,73 @@ def _process_named_tensor(self, named_tensor, collaborator_name): The numpy array associated with the returned tensorkey. """ raw_bytes = named_tensor.data_bytes - metadata = [{'int_to_float': proto.int_to_float, - 'int_list': proto.int_list, - 'bool_list': proto.bool_list} - for proto in named_tensor.transformer_metadata] + metadata = [{ + 'int_to_float': proto.int_to_float, + 'int_list': proto.int_list, + 'bool_list': proto.bool_list + } for proto in named_tensor.transformer_metadata] # The tensor has already been transfered to aggregator, # so the newly constructed tensor should have the aggregator origin - tensor_key = TensorKey( - named_tensor.name, - self.uuid, - named_tensor.round_number, - named_tensor.report, - tuple(named_tensor.tags) - ) + tensor_key = TensorKey(named_tensor.name, self.uuid, + named_tensor.round_number, named_tensor.report, + tuple(named_tensor.tags)) tensor_name, origin, round_number, report, tags = tensor_key - assert ('compressed' in tags or 'lossy_compressed' in tags), ( - f'Named tensor {tensor_key} is not compressed' - ) + assert ('compressed' in tags or 'lossy_compressed' + in tags), (f'Named tensor {tensor_key} is not compressed') if 'compressed' in tags: dec_tk, decompressed_nparray = self.tensor_codec.decompress( tensor_key, data=raw_bytes, transformer_metadata=metadata, - require_lossless=True - ) + require_lossless=True) dec_name, dec_origin, dec_round_num, dec_report, dec_tags = dec_tk # Need to add the collaborator tag to the resulting tensor new_tags = change_tags(dec_tags, add_field=collaborator_name) # layer.agg.n.trained.delta.col_i - decompressed_tensor_key = TensorKey( - dec_name, dec_origin, dec_round_num, dec_report, new_tags - ) + decompressed_tensor_key = TensorKey(dec_name, dec_origin, + dec_round_num, dec_report, + new_tags) if 'lossy_compressed' in tags: dec_tk, decompressed_nparray = self.tensor_codec.decompress( tensor_key, data=raw_bytes, transformer_metadata=metadata, - require_lossless=False - ) + require_lossless=False) dec_name, dec_origin, dec_round_num, dec_report, dec_tags = dec_tk new_tags = change_tags(dec_tags, add_field=collaborator_name) # layer.agg.n.trained.delta.lossy_decompressed.col_i - decompressed_tensor_key = TensorKey( - dec_name, dec_origin, dec_round_num, dec_report, new_tags - ) + decompressed_tensor_key = TensorKey(dec_name, dec_origin, + dec_round_num, dec_report, + new_tags) if 'delta' in tags: - base_model_tensor_key = TensorKey( - tensor_name, origin, round_number, report, ('model',) - ) + base_model_tensor_key = TensorKey(tensor_name, origin, + round_number, report, + ('model', )) base_model_nparray = self.tensor_db.get_tensor_from_cache( - base_model_tensor_key - ) + base_model_tensor_key) if base_model_nparray is None: - raise ValueError(f'Base model {base_model_tensor_key} not present in TensorDB') + raise ValueError( + f'Base model {base_model_tensor_key} not present in TensorDB' + ) final_tensor_key, final_nparray = self.tensor_codec.apply_delta( - decompressed_tensor_key, - decompressed_nparray, base_model_nparray - ) + decompressed_tensor_key, decompressed_nparray, + base_model_nparray) else: final_tensor_key = decompressed_tensor_key final_nparray = decompressed_nparray - assert (final_nparray is not None), f'Could not create tensorkey {final_tensor_key}' + assert (final_nparray + is not None), f'Could not create tensorkey {final_tensor_key}' self.tensor_db.cache_tensor({final_tensor_key: final_nparray}) self.logger.debug(f'Created TensorKey: {final_tensor_key}') return final_tensor_key, final_nparray def _end_of_task_check(self, task_name): - """Check whether all collaborators who are supposed to perform the task complete. + """Check whether all collaborators who are supposed to perform the + task complete. Args: task_name (str): Task name. @@ -705,16 +697,16 @@ def _end_of_task_check(self, task_name): # now check for the end of the round self._end_of_round_check() - def _prepare_trained(self, tensor_name, origin, round_number, report, agg_results): - """ - Prepare aggregated tensorkey tags. - - Args: - tensor_name (str): Tensor name. - origin: Origin. - round_number (int): Round number. - report (bool): Whether to report. - agg_results (np.array): Aggregated results. + def _prepare_trained(self, tensor_name, origin, round_number, report, + agg_results): + """Prepare aggregated tensorkey tags. + + Args: + tensor_name (str): Tensor name. + origin: Origin. + round_number (int): Round number. + report (bool): Whether to report. + agg_results (np.array): Aggregated results. """ # The aggregated tensorkey tags should have the form of # 'trained' or 'trained.lossy_decompressed' @@ -724,30 +716,18 @@ def _prepare_trained(self, tensor_name, origin, round_number, report, agg_result # First insert the aggregated model layer with the # correct tensorkey - agg_tag_tk = TensorKey( - tensor_name, - origin, - round_number + 1, - report, - ('aggregated',) - ) + agg_tag_tk = TensorKey(tensor_name, origin, round_number + 1, report, + ('aggregated', )) self.tensor_db.cache_tensor({agg_tag_tk: agg_results}) # Create delta and save it in TensorDB - base_model_tk = TensorKey( - tensor_name, - origin, - round_number, - report, - ('model',) - ) - base_model_nparray = self.tensor_db.get_tensor_from_cache(base_model_tk) + base_model_tk = TensorKey(tensor_name, origin, round_number, report, + ('model', )) + base_model_nparray = self.tensor_db.get_tensor_from_cache( + base_model_tk) if base_model_nparray is not None: delta_tk, delta_nparray = self.tensor_codec.generate_delta( - agg_tag_tk, - agg_results, - base_model_nparray - ) + agg_tag_tk, agg_results, base_model_nparray) else: # This condition is possible for base model # optimizer states (i.e. Adam/iter:0, SGD, etc.) @@ -757,8 +737,7 @@ def _prepare_trained(self, tensor_name, origin, round_number, report, agg_result # Compress lossless/lossy compressed_delta_tk, compressed_delta_nparray, metadata = self.tensor_codec.compress( - delta_tk, delta_nparray - ) + delta_tk, delta_nparray) # TODO extend the TensorDB so that compressed data is # supported. Once that is in place @@ -767,21 +746,18 @@ def _prepare_trained(self, tensor_name, origin, round_number, report, agg_result # Decompress lossless/lossy decompressed_delta_tk, decompressed_delta_nparray = self.tensor_codec.decompress( - compressed_delta_tk, - compressed_delta_nparray, - metadata - ) + compressed_delta_tk, compressed_delta_nparray, metadata) - self.tensor_db.cache_tensor({decompressed_delta_tk: decompressed_delta_nparray}) + self.tensor_db.cache_tensor( + {decompressed_delta_tk: decompressed_delta_nparray}) # Apply delta (unless delta couldn't be created) if base_model_nparray is not None: - self.logger.debug(f'Applying delta for layer {decompressed_delta_tk[0]}') + self.logger.debug( + f'Applying delta for layer {decompressed_delta_tk[0]}') new_model_tk, new_model_nparray = self.tensor_codec.apply_delta( - decompressed_delta_tk, - decompressed_delta_nparray, - base_model_nparray - ) + decompressed_delta_tk, decompressed_delta_nparray, + base_model_nparray) else: new_model_tk, new_model_nparray = decompressed_delta_tk, decompressed_delta_nparray @@ -790,13 +766,9 @@ def _prepare_trained(self, tensor_name, origin, round_number, report, agg_result # Relabel the tags to 'model' (new_model_tensor_name, new_model_origin, new_model_round_number, new_model_report, new_model_tags) = new_model_tk - final_model_tk = TensorKey( - new_model_tensor_name, - new_model_origin, - new_model_round_number, - new_model_report, - ('model',) - ) + final_model_tk = TensorKey(new_model_tensor_name, new_model_origin, + new_model_round_number, new_model_report, + ('model', )) # Finally, cache the updated model tensor self.tensor_db.cache_tensor({final_model_tk: new_model_nparray}) @@ -812,12 +784,12 @@ def _compute_validation_related_task_metrics(self, task_name): # This handles getting the subset of collaborators that may be # part of the validation task all_collaborators_for_task = self.assigner.get_collaborators_for_task( - task_name, self.round_number - ) + task_name, self.round_number) # leave out stragglers for the round collaborators_for_task = [] for c in all_collaborators_for_task: - if self._collaborator_task_completed(c, task_name, self.round_number): + if self._collaborator_task_completed(c, task_name, + self.round_number): collaborators_for_task.append(c) # The collaborator data sizes for that task @@ -836,8 +808,10 @@ def _compute_validation_related_task_metrics(self, task_name): # collaborator in our subset, and apply the correct # transformations to the tensorkey to resolve the aggregated # tensor for that round - task_agg_function = self.assigner.get_aggregation_type_for_task(task_name) - task_key = TaskResultKey(task_name, collaborators_for_task[0], self.round_number) + task_agg_function = self.assigner.get_aggregation_type_for_task( + task_name) + task_key = TaskResultKey(task_name, collaborators_for_task[0], + self.round_number) for tensor_key in self.collaborator_tasks_results[task_key]: tensor_name, origin, round_number, report, tags = tensor_key @@ -845,11 +819,16 @@ def _compute_validation_related_task_metrics(self, task_name): f'Tensor {tensor_key} in task {task_name} has not been processed correctly' ) # Strip the collaborator label, and lookup aggregated tensor - new_tags = change_tags(tags, remove_field=collaborators_for_task[0]) - agg_tensor_key = TensorKey(tensor_name, origin, round_number, report, new_tags) - agg_function = WeightedAverage() if 'metric' in tags else task_agg_function + new_tags = change_tags(tags, + remove_field=collaborators_for_task[0]) + agg_tensor_key = TensorKey(tensor_name, origin, round_number, + report, new_tags) + agg_function = WeightedAverage( + ) if 'metric' in tags else task_agg_function agg_results = self.tensor_db.get_aggregated_tensor( - agg_tensor_key, collaborator_weight_dict, aggregation_function=agg_function) + agg_tensor_key, + collaborator_weight_dict, + aggregation_function=agg_function) if report: # Caution: This schema must be followed. It is also used in @@ -869,12 +848,14 @@ def _compute_validation_related_task_metrics(self, task_name): if 'validate_agg' in tags: # Compare the accuracy of the model, potentially save it if self.best_model_score is None or self.best_model_score < agg_results: - self.logger.metric(f'Round {round_number}: saved the best ' - f'model with score {agg_results:f}') + self.logger.metric( + f'Round {round_number}: saved the best ' + f'model with score {agg_results:f}') self.best_model_score = agg_results self._save_model(round_number, self.best_state_path) if 'trained' in tags: - self._prepare_trained(tensor_name, origin, round_number, report, agg_results) + self._prepare_trained(tensor_name, origin, round_number, + report, agg_results) def _end_of_round_check(self): """Check if the round complete. @@ -889,7 +870,8 @@ def _end_of_round_check(self): Returns: None """ - if not self._is_round_done() or self._end_of_round_check_done[self.round_number]: + if not self._is_round_done() or self._end_of_round_check_done[ + self.round_number]: return # Compute all validation related metrics @@ -926,14 +908,12 @@ def _is_task_done(self, task_name): bool: Whether the task is done. """ all_collaborators = self.assigner.get_collaborators_for_task( - task_name, self.round_number - ) + task_name, self.round_number) collaborators_done = [] for c in all_collaborators: - if self._collaborator_task_completed( - c, task_name, self.round_number - ): + if self._collaborator_task_completed(c, task_name, + self.round_number): collaborators_done.append(c) straggler_check = self.straggler_handling_policy.straggler_cutoff_check( @@ -943,11 +923,14 @@ def _is_task_done(self, task_name): for c in all_collaborators: if c not in collaborators_done: self.stragglers.append(c) - self.logger.info(f'\tEnding task {task_name} early due to straggler cutoff policy') + self.logger.info( + f'\tEnding task {task_name} early due to straggler cutoff policy' + ) self.logger.warning(f'\tIdentified stragglers: {self.stragglers}') # all are done or straggler policy calls for early round end. - return straggler_check or len(all_collaborators) == len(collaborators_done) + return straggler_check or len(all_collaborators) == len( + collaborators_done) def _is_round_done(self): """Check that round is done. @@ -955,11 +938,11 @@ def _is_round_done(self): Returns: bool: Whether the round is done. """ - tasks_for_round = self.assigner.get_all_tasks_for_round(self.round_number) + tasks_for_round = self.assigner.get_all_tasks_for_round( + self.round_number) return all( - self._is_task_done( - task_name) for task_name in tasks_for_round) + self._is_task_done(task_name) for task_name in tasks_for_round) def _log_big_warning(self): """Warn user about single collaborator cert mode.""" @@ -967,8 +950,7 @@ def _log_big_warning(self): f'\n{the_dragon}\nYOU ARE RUNNING IN SINGLE COLLABORATOR CERT MODE! THIS IS' f' NOT PROPER PKI AND ' f'SHOULD ONLY BE USED IN DEVELOPMENT SETTINGS!!!! YE HAVE BEEN' - f' WARNED!!!' - ) + f' WARNED!!!') def stop(self, failed_collaborator: str = None) -> None: """Stop aggregator execution. @@ -987,8 +969,11 @@ def stop(self, failed_collaborator: str = None) -> None: # This code does not actually send `quit` tasks to collaborators, # it just mimics it by filling arrays. - for collaborator_name in filter(lambda c: c != failed_collaborator, self.authorized_cols): - self.logger.info(f'Sending signal to collaborator {collaborator_name} to shutdown...') + for collaborator_name in filter(lambda c: c != failed_collaborator, + self.authorized_cols): + self.logger.info( + f'Sending signal to collaborator {collaborator_name} to shutdown...' + ) self.quit_job_sent_to.append(collaborator_name)