Skip to content

Commit

Permalink
Add metric_writer
Browse files Browse the repository at this point in the history
Signed-off-by: Shah, Karan <[email protected]>
  • Loading branch information
MasterSkepticista committed Dec 21, 2024
1 parent e8894d6 commit aab8baf
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 2 deletions.
15 changes: 13 additions & 2 deletions openfl/callbacks/callback_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0
from openfl.callbacks.callback import Callback
from openfl.callbacks.memory_profiler import MemoryProfiler
from openfl.callbacks.metric_writer import MetricWriter


class CallbackList(Callback):
Expand All @@ -22,13 +23,14 @@ def __init__(
self,
callbacks: list,
add_memory_profiler=False,
add_metric_writer=False,
tensor_db=None,
**params,
):
super().__init__()
self.callbacks = _flatten(callbacks) if callbacks else []

self._add_default_callbacks(add_memory_profiler)
self._add_default_callbacks(add_memory_profiler, add_metric_writer)

self.set_tensor_db(tensor_db)
self.set_params(params)
Expand All @@ -45,16 +47,25 @@ def set_tensor_db(self, tensor_db):
for callback in self.callbacks:
callback.set_tensor_db(tensor_db)

def _add_default_callbacks(self, add_memory_profiler, add_metric_writer):
    """Register the built-in default callbacks when requested.

    Scans the user-provided callbacks for existing ``MemoryProfiler`` /
    ``MetricWriter`` instances so that defaults are only appended when
    they are not already present.
    """
    # Remember the last user-supplied instance of each default, if any
    # (reversed + first-match preserves the last-match-wins semantics).
    self._memory_profiler = next(
        (cb for cb in reversed(self.callbacks) if isinstance(cb, MemoryProfiler)),
        None,
    )
    self._metric_writer = next(
        (cb for cb in reversed(self.callbacks) if isinstance(cb, MetricWriter)),
        None,
    )

    if add_memory_profiler and self._memory_profiler is None:
        profiler = MemoryProfiler()
        self._memory_profiler = profiler
        self.callbacks.append(profiler)

    if add_metric_writer and self._metric_writer is None:
        writer = MetricWriter()
        self._metric_writer = writer
        self.callbacks.append(writer)

def on_round_begin(self, round_num: int, logs=None):
    """Notify every wrapped callback that a round has started.

    Args:
        round_num: The round that is beginning.
        logs: Optional key-value data forwarded to each callback.
    """
    for cb in self.callbacks:
        cb.on_round_begin(round_num, logs)
Expand Down
66 changes: 66 additions & 0 deletions openfl/callbacks/metric_writer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import json
import logging
import os

from tensorboardX import SummaryWriter

from openfl.callbacks.callback import Callback

logger = logging.getLogger(__name__)


class MetricWriter(Callback):
    """Log scalar metrics at the end of each round.

    Metrics are appended as lines of JSON to a per-origin file under
    ``log_dir`` and, optionally, mirrored as TensorBoard scalar summaries.

    Attributes:
        log_dir: Path to write logs as lines of JSON. Defaults to `./logs`.
        use_tensorboard: If True, writes scalar summaries to TensorBoard under `log_dir`.
    """

    def __init__(self, log_dir: str = "./logs/", use_tensorboard: bool = True):
        super().__init__()
        self.log_dir = log_dir
        self.use_tensorboard = use_tensorboard

        # Opened lazily in `on_experiment_begin` so construction stays cheap.
        self._log_file_handle = None
        self._summary_writer = None

    def on_experiment_begin(self, logs=None):
        """Open file handles for logging."""
        # Fix: `open(..., "a")` does not create intermediate directories, so a
        # missing log_dir would raise FileNotFoundError on first use.
        os.makedirs(self.log_dir, exist_ok=True)

        if not self._log_file_handle:
            # NOTE(review): assumes `params["origin"]` is always set by the
            # caller (presumably the participant name) — confirm upstream.
            self._log_file_handle = open(
                os.path.join(self.log_dir, self.params["origin"] + "_metrics.txt"), "a"
            )

        if self.use_tensorboard:
            self._summary_writer = SummaryWriter(
                os.path.join(self.log_dir, self.params["origin"] + "_tensorboard")
            )

    def on_round_end(self, round_num: int, logs=None):
        """Log metrics.

        Args:
            round_num: The current round number.
            logs: A key-value pair of scalar metrics.
        """
        logs = logs or {}
        logger.info(f"Round {round_num}: Metrics: {logs}")

        # Fix: guard against `on_experiment_begin` never having been called,
        # which would leave the handle as None and crash here.
        if self._log_file_handle:
            self._log_file_handle.write(json.dumps(logs) + "\n")
            self._log_file_handle.flush()

        if self._summary_writer:
            for key, value in logs.items():
                self._summary_writer.add_scalar(key, value, round_num)
            self._summary_writer.flush()

    def on_experiment_end(self, logs=None):
        """Close any handles opened for logging."""
        if self._log_file_handle:
            self._log_file_handle.close()
            self._log_file_handle = None

        if self._summary_writer:
            self._summary_writer.close()
            # Fix: drop the closed writer so it cannot be reused, mirroring
            # the file-handle cleanup above.
            self._summary_writer = None

0 comments on commit aab8baf

Please sign in to comment.