Introduce callbacks API #1195

Merged
39 commits
6fbbd77
Get rid of kwargs
MasterSkepticista Dec 5, 2024
fe5011a
Use module-level logger
MasterSkepticista Dec 5, 2024
edca03a
Reduce keras verbosity
MasterSkepticista Dec 5, 2024
31ad8ac
Remove all log_metric and log_memory_usage traces; add callback hooks
MasterSkepticista Dec 5, 2024
8857f03
Add `openfl.callbacks` module
MasterSkepticista Dec 5, 2024
49948cf
Merge branch 'develop' into karansh1/callbacks_api
MasterSkepticista Dec 6, 2024
fc6b2cb
Merge branch 'develop' into karansh1/callbacks_api
MasterSkepticista Dec 6, 2024
975e4ac
Include round_num for task callbacks
MasterSkepticista Dec 6, 2024
1267bf0
Add tensordb to callbacks
MasterSkepticista Dec 6, 2024
07e593d
Merge branch 'develop' into karansh1/callbacks_api
MasterSkepticista Dec 9, 2024
2c84c88
No round_num on task callbacks
MasterSkepticista Dec 9, 2024
0a380c4
Merge branch 'develop' into karansh1/callbacks_api
MasterSkepticista Dec 9, 2024
8d1aea3
Merge branch 'develop' into karansh1/callbacks_api
MasterSkepticista Dec 10, 2024
e74b9bf
Remove task boundary callbacks
MasterSkepticista Dec 10, 2024
d63ced5
Remove tb/model_ckpt. Add memory_profiler
MasterSkepticista Dec 10, 2024
9ce8983
Merge branch 'develop' into karansh1/callbacks_api
MasterSkepticista Dec 10, 2024
e103d63
Restore psutil and tbX
MasterSkepticista Dec 10, 2024
abb15da
Format code
MasterSkepticista Dec 10, 2024
5a2cd4b
Define default callbacks
MasterSkepticista Dec 10, 2024
23b8eb3
Add write_logs for bwd compat
MasterSkepticista Dec 10, 2024
4e32632
Add log_metric_callback for bwd compat
MasterSkepticista Dec 10, 2024
b501527
Migrate to module-level logger for collaborator
MasterSkepticista Dec 11, 2024
f5ebd1d
Merge branch 'develop' into karansh1/callbacks_api
MasterSkepticista Dec 11, 2024
af0c40f
Review comments
MasterSkepticista Dec 11, 2024
2ed63ef
Merge branch 'develop' into karansh1/callbacks_api
MasterSkepticista Dec 16, 2024
e8894d6
Merge branch 'develop' into karansh1/callbacks_api
MasterSkepticista Dec 21, 2024
aab8baf
Add metric_writer
MasterSkepticista Dec 21, 2024
3c5e525
Add collaborator side metric logging
MasterSkepticista Dec 21, 2024
fc76f18
Make log dirs on exp begin
MasterSkepticista Dec 21, 2024
c4eb30b
Do not print use_tls
MasterSkepticista Dec 21, 2024
9b3da0e
Assume reportable metric to be a scalar
MasterSkepticista Dec 21, 2024
1cffb9d
Add aggregator side callbacks
MasterSkepticista Dec 21, 2024
afd1bee
do_task test returns mock dict
MasterSkepticista Dec 21, 2024
25c00f1
Consistency changes
MasterSkepticista Dec 21, 2024
efbfdc7
Add documentation hooks
MasterSkepticista Dec 21, 2024
e7068e1
Update docstring
MasterSkepticista Dec 21, 2024
be06eda
Update docs hook
MasterSkepticista Dec 21, 2024
38f7c30
Remove all traces of log_metric_callback and write_metric
MasterSkepticista Dec 21, 2024
d01ec5e
Do on_round_begin if not time_to_quit
MasterSkepticista Dec 23, 2024
1 change: 0 additions & 1 deletion openfl-workspace/workspace/plan/defaults/aggregator.yaml
@@ -1,4 +1,3 @@
 template : openfl.component.Aggregator
 settings :
   db_store_rounds : 2
-  write_logs : true
6 changes: 6 additions & 0 deletions openfl/callbacks/__init__.py
@@ -0,0 +1,6 @@
# Copyright 2020-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from openfl.callbacks.callback import Callback
from openfl.callbacks.callback_list import CallbackList
from openfl.callbacks.lambda_callback import LambdaCallback
from openfl.callbacks.memory_profiler import MemoryProfiler
57 changes: 57 additions & 0 deletions openfl/callbacks/callback.py
@@ -0,0 +1,57 @@
# Copyright 2020-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""Callbacks API."""


class Callback:
    """Base class for callbacks.

    Callbacks can be used to perform actions at different stages of the
    Federated Learning process. To create a custom callback, subclass
    `openfl.callbacks.Callback` and implement the necessary methods.

    Callbacks can be triggered on the aggregator and collaborator side
    for the following events:
    * At the beginning of an experiment
    * At the beginning of a round
    * At the end of a round
    * At the end of an experiment

    Attributes:
        params: Additional parameters saved for use within the callback.
        tensor_db: The `TensorDB` instance of the respective participant.
    """

    def __init__(self):
        self.params = None
        self.tensor_db = None

    def set_params(self, params):
        self.params = params

    def set_tensor_db(self, tensor_db):
        self.tensor_db = tensor_db

    def on_round_begin(self, round_num: int, logs=None):
        """Callback function to be executed at the beginning of a round.

        Subclasses need to implement actions to be taken here.
        """

    def on_round_end(self, round_num: int, logs=None):
        """Callback function to be executed at the end of a round.

        Subclasses need to implement actions to be taken here.
        """

    def on_experiment_begin(self, logs=None):
        """Callback function to be executed at the beginning of an experiment.

        Subclasses need to implement actions to be taken here.
        """

    def on_experiment_end(self, logs=None):
        """Callback function to be executed at the end of an experiment.

        Subclasses need to implement actions to be taken here.
        """

Review comment (Collaborator), on `on_round_begin`:
Based on my read-through, I see that the `logs` named arg is only used by MetricWriter. For extensibility, would it make more sense to have `**kwargs` instead? We would still need to consider how these "dynamic" named args are sent by the aggregator/collaborator, but this could be thought through in subsequent changes to the callbacks API.

Reply (Collaborator, Author):
> For extensibility, would it make more sense to have `**kwargs` instead?

Generally speaking, `**kwargs` should be avoided. To pass information that specific callbacks need, we have `params`, declared when the callbacks list is instantiated. This makes it clear what attributes callbacks will have access to internally; dynamic setting makes code harder to inspect and test.

Currently, `logs` is the default because most callbacks trigger actions based on metrics: writing logs to a file, saving a checkpoint if `logs["loss"] < 0.2`, writing logs to TensorBoard, and so on.

That said, I agree that we can expand this argument list further if more callbacks benefit from it.
83 changes: 83 additions & 0 deletions openfl/callbacks/callback_list.py
@@ -0,0 +1,83 @@
# Copyright 2020-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from typing import List

from openfl.callbacks.callback import Callback
from openfl.callbacks.memory_profiler import MemoryProfiler


class CallbackList(Callback):
    """An ensemble of callbacks.

    This class allows multiple callbacks to be used together, by sequentially
    calling each callback's respective methods.

    Attributes:
        callbacks: A list of `openfl.callbacks.Callback` instances.
        add_memory_profiler: If True, adds a `MemoryProfiler` callback to the list.
        tensor_db: Optional `TensorDB` instance of the respective participant.
            If provided, callbacks can access TensorDB for various actions.
        params: Additional parameters saved for use within the callbacks.
    """

    def __init__(
        self,
        callbacks: List[Callback],
        add_memory_profiler=False,
        tensor_db=None,
        **params,
    ):
        super().__init__()
        # Materialize the generator so the list can be appended to and
        # iterated more than once.
        self.callbacks = list(_flatten(callbacks)) if callbacks else []

        self._add_default_callbacks(add_memory_profiler)

        self.set_tensor_db(tensor_db)
        self.set_params(params)

    def set_params(self, params):
        self.params = params
        if params:
            for callback in self.callbacks:
                callback.set_params(params)

    def set_tensor_db(self, tensor_db):
        self.tensor_db = tensor_db
        if tensor_db:
            for callback in self.callbacks:
                callback.set_tensor_db(tensor_db)

    def _add_default_callbacks(self, add_memory_profiler):
        self._memory_profiler = None
        for cb in self.callbacks:
            if isinstance(cb, MemoryProfiler):
                self._memory_profiler = cb

        if add_memory_profiler and self._memory_profiler is None:
            self._memory_profiler = MemoryProfiler()
            self.callbacks.append(self._memory_profiler)

    def on_round_begin(self, round_num: int, logs=None):
        for callback in self.callbacks:
            callback.on_round_begin(round_num, logs)

    def on_round_end(self, round_num: int, logs=None):
        for callback in self.callbacks:
            callback.on_round_end(round_num, logs)

    def on_experiment_begin(self, logs=None):
        for callback in self.callbacks:
            callback.on_experiment_begin(logs)

    def on_experiment_end(self, logs=None):
        for callback in self.callbacks:
            callback.on_experiment_end(logs)


def _flatten(l):
    """Flatten a possibly-nested tree of lists."""
    for elem in l:
        if isinstance(elem, list):
            yield from _flatten(elem)
        else:
            yield elem
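`_flatten` is what lets the constructor accept arbitrarily nested lists of callbacks. Its behavior can be sketched directly; the function body is copied from the diff above and the sample input is illustrative:

```python
def _flatten(l):
    """Flatten a possibly-nested tree of lists (copied from this diff)."""
    for elem in l:
        if isinstance(elem, list):
            # Recurse into sublists, yielding their leaves in order.
            yield from _flatten(elem)
        else:
            yield elem


flat = list(_flatten(["a", ["b", ["c"]], "d"]))
print(flat)  # ['a', 'b', 'c', 'd']
```

Note that `_flatten` is a generator, so it yields a one-shot iterator; the constructor must materialize it (e.g. with `list(...)`) before appending default callbacks or iterating over it more than once.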
38 changes: 38 additions & 0 deletions openfl/callbacks/lambda_callback.py
@@ -0,0 +1,38 @@
# Copyright 2020-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from openfl.callbacks.callback import Callback


class LambdaCallback(Callback):
    """Custom on-the-fly callbacks.

    This callback can be constructed with functions that will be called
    at the appropriate time during the life-cycle of a Federated Learning
    experiment. Certain callbacks may expect positional arguments, for example:

    * on_round_begin: expects `round_num` as a positional argument.
    * on_round_end: expects `round_num` as a positional argument.

    Args:
        on_round_begin: Called at the beginning of every round.
        on_round_end: Called at the end of every round.
        on_experiment_begin: Called at the beginning of an experiment.
        on_experiment_end: Called at the end of an experiment.
    """

    def __init__(
        self,
        on_round_begin=None,
        on_round_end=None,
        on_experiment_begin=None,
        on_experiment_end=None,
    ):
        super().__init__()
        if on_round_begin is not None:
            self.on_round_begin = on_round_begin
        if on_round_end is not None:
            self.on_round_end = on_round_end
        if on_experiment_begin is not None:
            self.on_experiment_begin = on_experiment_begin
        if on_experiment_end is not None:
            self.on_experiment_end = on_experiment_end
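`LambdaCallback` works by shadowing the base-class methods with instance attributes. A minimal standalone sketch of the pattern (the trimmed base class and the `seen` list are illustrative; in a workspace you would import `openfl.callbacks.LambdaCallback`):

```python
class Callback:
    """Trimmed stand-in for the base class: a no-op round-end hook."""

    def on_round_end(self, round_num, logs=None):
        pass


class LambdaCallback(Callback):
    def __init__(self, on_round_end=None):
        super().__init__()
        if on_round_end is not None:
            # The instance attribute shadows the method; callers see no difference.
            self.on_round_end = on_round_end


seen = []
cb = LambdaCallback(on_round_end=lambda round_num, logs=None: seen.append(round_num))
cb.on_round_end(3)
print(seen)  # [3]
```

Because the function is stored as a plain instance attribute, it is not bound to the instance, so it receives exactly the arguments the caller passes.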
64 changes: 64 additions & 0 deletions openfl/callbacks/memory_profiler.py
@@ -0,0 +1,64 @@
# Copyright 2020-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""Memory Profiler callback."""

import json
import logging
import os

import psutil

from openfl.callbacks.callback import Callback

logger = logging.getLogger(__name__)


class MemoryProfiler(Callback):
    """Profile memory usage of the current process at the end of each round.

    Attributes:
        log_dir: If set, writes logs as lines of JSON.
    """

    def __init__(self, log_dir: str = "./logs/"):
        super().__init__()
        self.log_dir = None
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
            self.log_dir = log_dir

    def on_round_end(self, round_num: int, logs=None):
        origin = self.params["origin"]

        info = _get_memory_usage()
        info["round_number"] = round_num
        info["origin"] = origin

        logger.info(f"Round {round_num}: Memory usage: {info}")
        if self.log_dir:
            with open(os.path.join(self.log_dir, f"{origin}_memory_usage.json"), "a") as f:
                f.write(json.dumps(info) + "\n")


def _get_memory_usage() -> dict:
    process = psutil.Process(os.getpid())
    virtual_memory = psutil.virtual_memory()
    swap_memory = psutil.swap_memory()
    info = {
        "process_memory": round(process.memory_info().rss / (1024**2), 2),
        "virtual_memory/total": round(virtual_memory.total / (1024**2), 2),
        "virtual_memory/available": round(virtual_memory.available / (1024**2), 2),
        "virtual_memory/percent": virtual_memory.percent,
        "virtual_memory/used": round(virtual_memory.used / (1024**2), 2),
        "virtual_memory/free": round(virtual_memory.free / (1024**2), 2),
        "virtual_memory/active": round(virtual_memory.active / (1024**2), 2),
        "virtual_memory/inactive": round(virtual_memory.inactive / (1024**2), 2),
        "virtual_memory/buffers": round(virtual_memory.buffers / (1024**2), 2),
        "virtual_memory/cached": round(virtual_memory.cached / (1024**2), 2),
        "virtual_memory/shared": round(virtual_memory.shared / (1024**2), 2),
        "swap_memory/total": round(swap_memory.total / (1024**2), 2),
        "swap_memory/used": round(swap_memory.used / (1024**2), 2),
        "swap_memory/free": round(swap_memory.free / (1024**2), 2),
        "swap_memory/percent": swap_memory.percent,
    }
    return info
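Every entry in `_get_memory_usage` converts a byte count to MiB with the same `round(x / (1024**2), 2)` expression. A standalone sketch of that conversion (the helper name `to_mib` is hypothetical; the diff inlines the expression):

```python
def to_mib(num_bytes: int) -> float:
    """Convert a byte count to mebibytes, rounded to two decimals."""
    return round(num_bytes / (1024**2), 2)


print(to_mib(3670016))  # 3.5
print(to_mib(1048576))  # 1.0
```

Note that fields like `virtual_memory.buffers`, `cached`, and `shared` are Linux-specific in psutil, so this profiler's full output is only available on Linux hosts.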