diff --git a/openfl/component/aggregator/aggregator.py b/openfl/component/aggregator/aggregator.py index 038a1e00d3d..041022e3007 100644 --- a/openfl/component/aggregator/aggregator.py +++ b/openfl/component/aggregator/aggregator.py @@ -79,7 +79,7 @@ def __init__( write_logs=False, log_memory_usage=False, log_metric_callback=None, - **kwargs, + initial_tensor_dict=None, ): """Initializes the Aggregator. @@ -110,15 +110,18 @@ def __init__( to None. **kwargs: Additional keyword arguments. """ + self.logger = getLogger(__name__) self.round_number = 0 - self.single_col_cert_common_name = single_col_cert_common_name - if self.single_col_cert_common_name is not None: - self._log_big_warning() - else: - # FIXME: '' instead of None is just for protobuf compatibility. - # Cleaner solution? - self.single_col_cert_common_name = "" + if single_col_cert_common_name: + self.logger.warning( + "You are running in single collaborator certificate mode. " + "This mode is intended for development settings only and does not " + "provide proper Public Key Infrastructure (PKI) security. " + "Please use this mode with caution." + ) + # FIXME: "" instead of None is for protobuf compatibility. + self.single_col_cert_common_name = single_col_cert_common_name or "" self.straggler_handling_policy = ( straggler_handling_policy or CutoffTimeBasedStragglerHandling() @@ -143,7 +146,6 @@ def __init__( self.db_store_rounds = db_store_rounds # Gathered together logging-related objects - self.logger = getLogger(__name__) self.write_logs = write_logs self.log_metric_callback = log_metric_callback @@ -166,10 +168,10 @@ def __init__( self.best_tensor_dict: dict = {} self.last_tensor_dict: dict = {} - if kwargs.get("initial_tensor_dict", None) is not None: - self._load_initial_tensors_from_dict(kwargs["initial_tensor_dict"]) + if initial_tensor_dict: + self._load_initial_tensors_from_dict(initial_tensor_dict) self.model = utils.construct_model_proto( - tensor_dict=kwargs["initial_tensor_dict"], + tensor_dict=initial_tensor_dict, round_number=0, tensor_pipe=self.compression_pipeline, ) @@ -326,9 +328,7 @@ def _time_to_quit(self): Returns: bool: True if it's time to quit, False otherwise. """ - if self.round_number >= self.rounds_to_train: - return True - return False + return self.round_number >= self.rounds_to_train def get_tasks(self, collaborator_name): """RPC called by a collaborator to determine which tasks to perform. @@ -1056,19 +1056,10 @@ def _is_collaborator_done(self, collaborator_name: str, round_number: int) -> No if all_tasks_completed: self.collaborators_done.append(collaborator_name) self.logger.info( - f"Round: {self.round_number}, Collaborators that have completed all tasks: " + f"Round {self.round_number}: Collaborators that have completed all tasks: " f"{self.collaborators_done}" ) - def _log_big_warning(self): - """Warn user about single collaborator cert mode.""" - self.logger.warning( - f"\n{the_dragon}\nYOU ARE RUNNING IN SINGLE COLLABORATOR CERT MODE! THIS IS" - f" NOT PROPER PKI AND " - f"SHOULD ONLY BE USED IN DEVELOPMENT SETTINGS!!!! YE HAVE BEEN" - f" WARNED!!!" - ) - def stop(self, failed_collaborator: str = None) -> None: """Stop aggregator execution. @@ -1092,76 +1083,3 @@ def stop(self, failed_collaborator: str = None) -> None: collaborator_name, ) self.quit_job_sent_to.append(collaborator_name) - - -the_dragon = """ - - ,@@.@@+@@##@,@@@@.`@@#@+ *@@@@ #@##@ `@@#@# @@@@@ @@ @@@@` #@@@ :@@ `@#`@@@#.@ - @@ #@ ,@ +. @@.@* #@ :` @+*@ .@`+. @@ *@::@`@@ @@# @@ #`;@`.@@ @@@`@`#@* +:@` - @@@@@ ,@@@ @@@@ +@@+ @@@@ .@@@ @@ .@+:@@@: .;+@` @@ ,;,#@` @@ @@@@@ ,@@@* @ - @@ #@ ,@`*. @@.@@ #@ ,; `@+,@#.@.*` @@ ,@::@`@@` @@@@# @@`:@;*@+ @@ @`:@@`@ *@@ ` - .@@`@@,+@+;@.@@ @@`@@;*@ ;@@#@:*@+;@ `@@;@@ #@**@+;@ `@@:`@@@@ @@@@.`@+ .@ +@+@*,@ - `` `` ` `` . ` ` ` ` ` .` ` `` `` `` ` . ` - - - - .** - ;` `****: - @**`******* - *** +***********; - ,@***;` .*:,;************ - ;***********@@*********** - ;************************, - `************************* - ************************* - ,************************ - **#********************* - *@****` :**********; - +**; .********. - ;*; `*******#: `,: - ****@@@++:: ,,;***. - *@@@**;#;: +: **++*, - @***#@@@: +*; ,**** - @*@+**** ***` ****, - ,@#******. , **** **;,**. - * ******** :, ;*:*+ ** :,** - # ********:: *,.*:**` * ,*; - . *********: .+,*:;*: : `:** - ; :********: ***::** ` ` ** - + :****::*** , *;;::**` :* - `` .****::;**::: *;::::*; ;* - * *****::***:. **::::** ;: - # *****;:**** ;*::;*** ,*` - ; ************` ,**:****; ::* - : *************;:;*;*++: *. - : *****************;* `* - `. `*****************; : *. - .` .*+************+****;: :* - `. :;+***********+******;` : .,* - ; ::*+*******************. `:: .`:. - + :::**********************;;:` * - + ,::;*************;:::*******. * - # `:::+*************:::;******** :, * - @ :::***************;:;*********;:, * - @ ::::******:*********************: ,:* - @ .:::******:;*********************, :* - # :::******::******###@*******;;**** *, - # .::;*****::*****#****@*****;:::***; `` ** - * ::;***********+*****+#******::*****,,,,** - : :;***********#******#****************** - .` `;***********#******+****+************ - `, ***#**@**+***+*****+**************;` - ; *++**#******#+****+` `.,.. - + `@***#*******#****# - + +***@********+**+: - * .+**+;**;;;**;#**# - ,` ****@ +*+: - # +**+ :+** - @ ;**+, ,***+ - # #@+**** *#****+ - `; @+***+@ `#**+#++ - # #*#@##, .++:.,# - `* @# +. - @@@ - # `@ - , """ diff --git a/openfl/component/collaborator/collaborator.py b/openfl/component/collaborator/collaborator.py index 54683135d53..119efd5eeee 100644 --- a/openfl/component/collaborator/collaborator.py +++ b/openfl/component/collaborator/collaborator.py @@ -82,7 +82,6 @@ def __init__( compression_pipeline=None, db_store_rounds=1, log_memory_usage=False, - **kwargs, ): """Initialize the Collaborator object. @@ -104,7 +103,6 @@ def __init__( Defaults to None. db_store_rounds (int, optional): The number of rounds to store in the database. Defaults to 1. - **kwargs: Variable length argument list. """ self.single_col_cert_common_name = None diff --git a/openfl/federated/task/runner_keras.py b/openfl/federated/task/runner_keras.py index b33b20fde57..2d992c83bb9 100644 --- a/openfl/federated/task/runner_keras.py +++ b/openfl/federated/task/runner_keras.py @@ -194,7 +194,7 @@ def train_iteration(self, batch_generator, metrics: list = None, **kwargs): f"Param_metrics = {metrics}, model_metrics_names = {model_metrics_names}" ) - history = self.model.fit(batch_generator, verbose=1, **kwargs) + history = self.model.fit(batch_generator, verbose=2, **kwargs) results = [] for metric in metrics: value = np.mean([history.history[metric]]) @@ -224,7 +224,7 @@ def validate(self, col_name, round_num, input_tensor_dict, **kwargs): self.rebuild_model(round_num, input_tensor_dict, validation=True) param_metrics = kwargs["metrics"] - vals = self.model.evaluate(self.data_loader.get_valid_loader(batch_size), verbose=1) + vals = self.model.evaluate(self.data_loader.get_valid_loader(batch_size), verbose=2) model_metrics_names = self.model.metrics_names if type(vals) is not list: vals = [vals]