From 51df6d6999e1a09bca8c05b9a0fe11fd6576d156 Mon Sep 17 00:00:00 2001
From: "Shah, Karan"
Date: Wed, 4 Dec 2024 10:23:01 +0530
Subject: [PATCH] Revert "Remove tensorboardX and never-used log_metric code"
 and add fn calls

Signed-off-by: Shah, Karan
---
 openfl/component/aggregator/aggregator.py | 35 ++++++++++++++++++++++-
 openfl/utilities/logs.py                  | 30 +++++++++++++++++++
 2 files changed, 64 insertions(+), 1 deletion(-)

diff --git a/openfl/component/aggregator/aggregator.py b/openfl/component/aggregator/aggregator.py
index 72296130b6..038a1e00d3 100644
--- a/openfl/component/aggregator/aggregator.py
+++ b/openfl/component/aggregator/aggregator.py
@@ -15,7 +15,7 @@
 from openfl.pipelines import NoCompressionPipeline, TensorCodec
 from openfl.protocols import base_pb2, utils
 from openfl.utilities import TaskResultKey, TensorKey, change_tags
-from openfl.utilities.logs import get_memory_usage
+from openfl.utilities.logs import get_memory_usage, write_metric
 
 
 class Aggregator:
@@ -38,6 +38,8 @@ class Aggregator:
         tensor_db (TensorDB): Object for tensor database.
         db_store_rounds* (int): Rounds to store in TensorDB.
         logger: Object for logging.
+        write_logs (bool): Flag to enable log writing.
+        log_metric_callback: Callback for logging metrics.
         best_model_score (optional): Score of the best model. Defaults to
             None.
         metric_queue (queue.Queue): Queue for metrics.
@@ -74,7 +76,9 @@ def __init__(
         single_col_cert_common_name=None,
         compression_pipeline=None,
         db_store_rounds=1,
+        write_logs=False,
         log_memory_usage=False,
+        log_metric_callback=None,
         **kwargs,
     ):
         """Initializes the Aggregator.
@@ -100,6 +104,10 @@ def __init__(
                 NoCompressionPipeline.
             db_store_rounds (int, optional): Rounds to store in TensorDB.
                 Defaults to 1.
+            write_logs (bool, optional): Whether to write logs. Defaults to
+                False.
+            log_metric_callback (optional): Callback for log metric. Defaults
+                to None.
             **kwargs: Additional keyword arguments.
         """
         self.round_number = 0
@@ -136,6 +144,15 @@ def __init__(
 
         # Gathered together logging-related objects
         self.logger = getLogger(__name__)
+        self.write_logs = write_logs
+        self.log_metric_callback = log_metric_callback
+
+        if self.write_logs:
+            self.log_metric = write_metric
+            if self.log_metric_callback:
+                self.log_metric = log_metric_callback
+                self.logger.info("Using custom log metric: %s", self.log_metric)
+
         self.best_model_score = None
         self.metric_queue = queue.Queue()
 
@@ -647,6 +664,14 @@ def send_local_task_results(
                 }
                 self.metric_queue.put(metrics)
                 self.logger.metric("%s", str(metrics))
+                if self.write_logs:
+                    self.log_metric(
+                        collaborator_name,
+                        task_name,
+                        tensor_key.tensor_name,
+                        float(value),
+                        round_number,
+                    )
 
             task_results.append(tensor_key)
 
@@ -921,6 +946,14 @@ def _compute_validation_related_task_metrics(self, task_name):
 
                 self.metric_queue.put(metrics)
                 self.logger.metric("%s", metrics)
+                if self.write_logs:
+                    self.log_metric(
+                        "aggregator",
+                        task_name,
+                        tensor_key.tensor_name,
+                        float(agg_results),
+                        round_number,
+                    )
 
             # FIXME: Configurable logic for min/max criteria in saving best.
             if "validate_agg" in tags:
diff --git a/openfl/utilities/logs.py b/openfl/utilities/logs.py
index 747e2165df..a17f0742fc 100644
--- a/openfl/utilities/logs.py
+++ b/openfl/utilities/logs.py
@@ -10,6 +10,36 @@
 import psutil
 from rich.console import Console
 from rich.logging import RichHandler
+from tensorboardX import SummaryWriter
+
+writer = None
+
+
+def get_writer():
+    """Create global writer object.
+
+    This function creates a global `SummaryWriter` object for logging to
+    TensorBoard.
+    """
+    global writer
+    if not writer:
+        writer = SummaryWriter("./logs/tensorboard", flush_secs=5)
+
+
+def write_metric(node_name, task_name, metric_name, metric, round_number):
+    """Write metric callback.
+
+    This function logs a metric to TensorBoard.
+
+    Args:
+        node_name (str): The name of the node.
+        task_name (str): The name of the task.
+        metric_name (str): The name of the metric.
+        metric (float): The value of the metric.
+        round_number (int): The current round number.
+    """
+    get_writer()
+    writer.add_scalar(f"{node_name}/{task_name}/{metric_name}", metric, round_number)
 
 
 def setup_loggers(log_level=logging.INFO):
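
Note (not part of the patch): a minimal usage sketch of the restored write_metric helper, assuming tensorboardX is installed and the logs.py changes above are applied. The node, task, and metric names and the values below are illustrative only.

    # Illustrative only: calls the restored write_metric helper directly.
    # Inside write_metric, get_writer() lazily creates a SummaryWriter under
    # ./logs/tensorboard, and add_scalar records the value under the tag
    # "<node_name>/<task_name>/<metric_name>" at step round_number.
    from openfl.utilities.logs import write_metric

    write_metric(
        node_name="aggregator",  # hypothetical node name
        task_name="train",       # hypothetical task name
        metric_name="loss",      # hypothetical metric name
        metric=0.42,             # hypothetical metric value
        round_number=0,
    )

    # The resulting event files can then be inspected with:
    #   tensorboard --logdir ./logs/tensorboard

In the aggregator these calls only run when write_logs is enabled; in a workspace that is typically done through the aggregator settings in plan.yaml, which are forwarded to Aggregator.__init__ as keyword arguments.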