Skip to content

Commit

Permalink
Revert "Remove tensorboardX and never-used log_metric code" and add f…
Browse files Browse the repository at this point in the history
…n calls

Signed-off-by: Shah, Karan <[email protected]>
  • Loading branch information
MasterSkepticista committed Dec 4, 2024
1 parent 6017900 commit 51df6d6
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 1 deletion.
35 changes: 34 additions & 1 deletion openfl/component/aggregator/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from openfl.pipelines import NoCompressionPipeline, TensorCodec
from openfl.protocols import base_pb2, utils
from openfl.utilities import TaskResultKey, TensorKey, change_tags
from openfl.utilities.logs import get_memory_usage
from openfl.utilities.logs import get_memory_usage, write_metric


class Aggregator:
Expand All @@ -38,6 +38,8 @@ class Aggregator:
tensor_db (TensorDB): Object for tensor database.
db_store_rounds* (int): Rounds to store in TensorDB.
logger: Object for logging.
write_logs (bool): Flag to enable log writing.
log_metric_callback: Callback for logging metrics.
best_model_score (optional): Score of the best model. Defaults to
None.
metric_queue (queue.Queue): Queue for metrics.
Expand Down Expand Up @@ -74,7 +76,9 @@ def __init__(
single_col_cert_common_name=None,
compression_pipeline=None,
db_store_rounds=1,
write_logs=False,
log_memory_usage=False,
log_metric_callback=None,
**kwargs,
):
"""Initializes the Aggregator.
Expand All @@ -100,6 +104,10 @@ def __init__(
NoCompressionPipeline.
db_store_rounds (int, optional): Rounds to store in TensorDB.
Defaults to 1.
write_logs (bool, optional): Whether to write logs. Defaults to
False.
log_metric_callback (optional): Callback for log metric. Defaults
to None.
**kwargs: Additional keyword arguments.
"""
self.round_number = 0
Expand Down Expand Up @@ -136,6 +144,15 @@ def __init__(

# Gathered together logging-related objects
self.logger = getLogger(__name__)
self.write_logs = write_logs
self.log_metric_callback = log_metric_callback

if self.write_logs:
self.log_metric = write_metric
if self.log_metric_callback:
self.log_metric = log_metric_callback
self.logger.info("Using custom log metric: %s", self.log_metric)

self.best_model_score = None
self.metric_queue = queue.Queue()

Expand Down Expand Up @@ -647,6 +664,14 @@ def send_local_task_results(
}
self.metric_queue.put(metrics)
self.logger.metric("%s", str(metrics))
if self.write_logs:
self.log_metric(
collaborator_name,
task_name,
tensor_key.tensor_name,
float(value),
round_number,
)

task_results.append(tensor_key)

Expand Down Expand Up @@ -921,6 +946,14 @@ def _compute_validation_related_task_metrics(self, task_name):

self.metric_queue.put(metrics)
self.logger.metric("%s", metrics)
if self.write_logs:
self.log_metric(
"aggregator",
task_name,
tensor_key.tensor_name,
float(agg_results),
round_number,
)

# FIXME: Configurable logic for min/max criteria in saving best.
if "validate_agg" in tags:
Expand Down
30 changes: 30 additions & 0 deletions openfl/utilities/logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,36 @@
import psutil
from rich.console import Console
from rich.logging import RichHandler
from tensorboardX import SummaryWriter

writer = None


def get_writer():
"""Create global writer object.
This function creates a global `SummaryWriter` object for logging to
TensorBoard.
"""
global writer
if not writer:
writer = SummaryWriter("./logs/tensorboard", flush_secs=5)


def write_metric(node_name, task_name, metric_name, metric, round_number):
"""Write metric callback.
This function logs a metric to TensorBoard.
Args:
node_name (str): The name of the node.
task_name (str): The name of the task.
metric_name (str): The name of the metric.
metric (float): The value of the metric.
round_number (int): The current round number.
"""
get_writer()
writer.add_scalar(f"{node_name}/{task_name}/{metric_name}", metric, round_number)


def setup_loggers(log_level=logging.INFO):
Expand Down

0 comments on commit 51df6d6

Please sign in to comment.