diff --git a/openfl/callbacks/callback_list.py b/openfl/callbacks/callback_list.py index 477a16986c..bc8890f4fb 100644 --- a/openfl/callbacks/callback_list.py +++ b/openfl/callbacks/callback_list.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 from openfl.callbacks.callback import Callback from openfl.callbacks.memory_profiler import MemoryProfiler +from openfl.callbacks.metric_writer import MetricWriter class CallbackList(Callback): @@ -22,13 +23,14 @@ def __init__( self, callbacks: list, add_memory_profiler=False, + add_metric_writer=False, tensor_db=None, **params, ): super().__init__() self.callbacks = _flatten(callbacks) if callbacks else [] - self._add_default_callbacks(add_memory_profiler) + self._add_default_callbacks(add_memory_profiler, add_metric_writer) self.set_tensor_db(tensor_db) self.set_params(params) @@ -45,16 +47,25 @@ def set_tensor_db(self, tensor_db): for callback in self.callbacks: callback.set_tensor_db(tensor_db) - def _add_default_callbacks(self, add_memory_profiler): + def _add_default_callbacks(self, add_memory_profiler, add_metric_writer): + """Add default callbacks to callbacks list if not already present.""" self._memory_profiler = None + self._metric_writer = None + for cb in self.callbacks: if isinstance(cb, MemoryProfiler): self._memory_profiler = cb + if isinstance(cb, MetricWriter): + self._metric_writer = cb if add_memory_profiler and self._memory_profiler is None: self._memory_profiler = MemoryProfiler() self.callbacks.append(self._memory_profiler) + if add_metric_writer and self._metric_writer is None: + self._metric_writer = MetricWriter() + self.callbacks.append(self._metric_writer) + def on_round_begin(self, round_num: int, logs=None): for callback in self.callbacks: callback.on_round_begin(round_num, logs) diff --git a/openfl/callbacks/metric_writer.py b/openfl/callbacks/metric_writer.py new file mode 100644 index 0000000000..131807b909 --- /dev/null +++ b/openfl/callbacks/metric_writer.py @@ -0,0 +1,66 @@ +import json +import logging +import os + +from tensorboardX import SummaryWriter + +from openfl.callbacks.callback import Callback + +logger = logging.getLogger(__name__) + + +class MetricWriter(Callback): + """Log scalar metrics at the end of each round. + + Attributes: + log_dir: Path to write logs as lines of JSON. Defaults to `./logs`. + use_tensorboard: If True, writes scalar summaries to TensorBoard under `log_dir`. + """ + + def __init__(self, log_dir: str = "./logs/", use_tensorboard: bool = True): + super().__init__() + self.log_dir = log_dir + self.use_tensorboard = use_tensorboard + + self._log_file_handle = None + self._summary_writer = None + + def on_experiment_begin(self, logs=None): + """Open file handles for logging.""" + + if not self._log_file_handle: + self._log_file_handle = open( + os.path.join(self.log_dir, self.params["origin"] + "_metrics.txt"), "a" + ) + + if self.use_tensorboard: + self._summary_writer = SummaryWriter( + os.path.join(self.log_dir, self.params["origin"] + "_tensorboard") + ) + + def on_round_end(self, round_num: int, logs=None): + """Log metrics. + + Args: + round_num: The current round number. + logs: A key-value pair of scalar metrics. + """ + logs = logs or {} + logger.info(f"Round {round_num}: Metrics: {logs}") + + self._log_file_handle.write(json.dumps(logs) + "\n") + self._log_file_handle.flush() + + if self._summary_writer: + for key, value in logs.items(): + self._summary_writer.add_scalar(key, value, round_num) + self._summary_writer.flush() + + def on_experiment_end(self, logs=None): + """Cleanup.""" + if self._log_file_handle: + self._log_file_handle.close() + self._log_file_handle = None + + if self._summary_writer: + self._summary_writer.close()