Introduce callbacks API #1195
Merged: psfoley merged 39 commits into securefederatedai:develop from MasterSkepticista:karansh1/callbacks_api on Dec 23, 2024
Changes from 23 commits
Commits
6fbbd77 Get rid of kwargs (MasterSkepticista)
fe5011a Use module-level logger (MasterSkepticista)
edca03a Reduce keras verbosity (MasterSkepticista)
31ad8ac Remove all log_metric and log_memory_usage traces; add callback hooks (MasterSkepticista)
8857f03 Add `openfl.callbacks` module (MasterSkepticista)
49948cf Merge branch 'develop' into karansh1/callbacks_api (MasterSkepticista)
fc6b2cb Merge branch 'develop' into karansh1/callbacks_api (MasterSkepticista)
975e4ac Include round_num for task callbacks (MasterSkepticista)
1267bf0 Add tensordb to callbacks (MasterSkepticista)
07e593d Merge branch 'develop' into karansh1/callbacks_api (MasterSkepticista)
2c84c88 No round_num on task callbacks (MasterSkepticista)
0a380c4 Merge branch 'develop' into karansh1/callbacks_api (MasterSkepticista)
8d1aea3 Merge branch 'develop' into karansh1/callbacks_api (MasterSkepticista)
e74b9bf Remove task boundary callbacks (MasterSkepticista)
d63ced5 Remove tb/model_ckpt. Add memory_profiler (MasterSkepticista)
9ce8983 Merge branch 'develop' into karansh1/callbacks_api (MasterSkepticista)
e103d63 Restore psutil and tbX (MasterSkepticista)
abb15da Format code (MasterSkepticista)
5a2cd4b Define default callbacks (MasterSkepticista)
23b8eb3 Add write_logs for bwd compat (MasterSkepticista)
4e32632 Add log_metric_callback for bwd compat (MasterSkepticista)
b501527 Migrate to module-level logger for collaborator (MasterSkepticista)
f5ebd1d Merge branch 'develop' into karansh1/callbacks_api (MasterSkepticista)
af0c40f Review comments (MasterSkepticista)
2ed63ef Merge branch 'develop' into karansh1/callbacks_api (MasterSkepticista)
e8894d6 Merge branch 'develop' into karansh1/callbacks_api (MasterSkepticista)
aab8baf Add metric_writer (MasterSkepticista)
3c5e525 Add collaborator side metric logging (MasterSkepticista)
fc76f18 Make log dirs on exp begin (MasterSkepticista)
c4eb30b Do not print use_tls (MasterSkepticista)
9b3da0e Assume reportable metric to be a scalar (MasterSkepticista)
1cffb9d Add aggregator side callbacks (MasterSkepticista)
afd1bee do_task test returns mock dict (MasterSkepticista)
25c00f1 Consistency changes (MasterSkepticista)
efbfdc7 Add documentation hooks (MasterSkepticista)
e7068e1 Update docstring (MasterSkepticista)
be06eda Update docs hook (MasterSkepticista)
38f7c30 Remove all traces of log_metric_callback and write_metric (MasterSkepticista)
d01ec5e Do on_round_begin if not time_to_quit (MasterSkepticista)
@@ -1,4 +1,3 @@
template : openfl.component.Aggregator
settings :
    db_store_rounds : 2
    write_logs : true
@@ -0,0 +1,6 @@
# Copyright 2020-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from openfl.callbacks.callback import Callback
from openfl.callbacks.callback_list import CallbackList
from openfl.callbacks.lambda_callback import LambdaCallback
from openfl.callbacks.memory_profiler import MemoryProfiler
@@ -0,0 +1,57 @@
# Copyright 2020-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""Callbacks API."""


class Callback:
    """Base class for callbacks.

    Callbacks can be used to perform actions at different stages of the
    Federated Learning process. To create a custom callback, subclass
    `openfl.callbacks.Callback` and implement the necessary methods.

    Callbacks can be triggered on the aggregator and collaborator side
    for the following events:
    * At the beginning of an experiment
    * At the beginning of a round
    * At the end of a round
    * At the end of an experiment

    Attributes:
        params: Additional parameters saved for use within the callback.
        tensor_db: The `TensorDB` instance of the respective participant.
    """

    def __init__(self):
        self.params = None
        self.tensor_db = None

    def set_params(self, params):
        self.params = params

    def set_tensor_db(self, tensor_db):
        self.tensor_db = tensor_db

    def on_round_begin(self, round_num: int, logs=None):
        """Callback function to be executed at the beginning of a round.

        Subclasses need to implement actions to be taken here.
        """

    def on_round_end(self, round_num: int, logs=None):
        """Callback function to be executed at the end of a round.

        Subclasses need to implement actions to be taken here.
        """

    def on_experiment_begin(self, logs=None):
        """Callback function to be executed at the beginning of an experiment.

        Subclasses need to implement actions to be taken here.
        """

    def on_experiment_end(self, logs=None):
        """Callback function to be executed at the end of an experiment.

        Subclasses need to implement actions to be taken here.
        """
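As an illustration of how a subclass might use the four hooks of the `Callback` base class, here is a minimal sketch. The `Callback` class below is a local stand-in that mirrors the base class from this diff so the snippet runs without OpenFL installed; `MetricHistory` and the `loss` metric name are hypothetical and not part of this PR.

```python
class Callback:
    """Local stand-in mirroring the base class shown in this diff."""

    def __init__(self):
        self.params = None
        self.tensor_db = None

    def set_params(self, params):
        self.params = params

    def set_tensor_db(self, tensor_db):
        self.tensor_db = tensor_db

    def on_round_begin(self, round_num: int, logs=None):
        pass

    def on_round_end(self, round_num: int, logs=None):
        pass

    def on_experiment_begin(self, logs=None):
        pass

    def on_experiment_end(self, logs=None):
        pass


class MetricHistory(Callback):
    """Hypothetical callback: keeps the metrics reported at each round end."""

    def on_experiment_begin(self, logs=None):
        self.history = {}

    def on_round_end(self, round_num: int, logs=None):
        # `logs` defaults to None, so guard before copying it.
        self.history[round_num] = dict(logs or {})


# Drive the hooks by hand, in the order a participant would trigger them.
cb = MetricHistory()
cb.on_experiment_begin()
for round_num, loss in enumerate([0.9, 0.5, 0.3]):
    cb.on_round_begin(round_num)
    cb.on_round_end(round_num, logs={"loss": loss})
cb.on_experiment_end()
print(cb.history)
```

Only the hooks a subclass cares about need to be overridden; the rest fall through to the base-class no-ops.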
@@ -0,0 +1,83 @@
# Copyright 2020-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from typing import List

from openfl.callbacks.callback import Callback
from openfl.callbacks.memory_profiler import MemoryProfiler


class CallbackList(Callback):
    """An ensemble of callbacks.

    This class allows multiple callbacks to be used together, by sequentially
    calling each callback's respective methods.

    Attributes:
        callbacks: A list of `openfl.callbacks.Callback` instances.
        add_memory_profiler: If True, adds a `MemoryProfiler` callback to the list.
        tensor_db: Optional `TensorDB` instance of the respective participant.
            If provided, callbacks can access TensorDB for various actions.
        params: Additional parameters saved for use within the callbacks.
    """

    def __init__(
        self,
        callbacks: List[Callback],
        add_memory_profiler=False,
        tensor_db=None,
        **params,
    ):
        super().__init__()
        self.callbacks = _flatten(callbacks) if callbacks else []

        self._add_default_callbacks(add_memory_profiler)

        self.set_tensor_db(tensor_db)
        self.set_params(params)

    def set_params(self, params):
        self.params = params
        if params:
            for callback in self.callbacks:
                callback.set_params(params)

    def set_tensor_db(self, tensor_db):
        self.tensor_db = tensor_db
        if tensor_db:
            for callback in self.callbacks:
                callback.set_tensor_db(tensor_db)

    def _add_default_callbacks(self, add_memory_profiler):
        self._memory_profiler = None
        for cb in self.callbacks:
            if isinstance(cb, MemoryProfiler):
                self._memory_profiler = cb

        if add_memory_profiler and self._memory_profiler is None:
            self._memory_profiler = MemoryProfiler()
            self.callbacks.append(self._memory_profiler)

    def on_round_begin(self, round_num: int, logs=None):
        for callback in self.callbacks:
            callback.on_round_begin(round_num, logs)

    def on_round_end(self, round_num: int, logs=None):
        for callback in self.callbacks:
            callback.on_round_end(round_num, logs)

    def on_experiment_begin(self, logs=None):
        for callback in self.callbacks:
            callback.on_experiment_begin(logs)

    def on_experiment_end(self, logs=None):
        for callback in self.callbacks:
            callback.on_experiment_end(logs)


def _flatten(l):
    """Flatten a possibly-nested tree of lists."""
    for elem in l:
        if isinstance(elem, list):
            yield from _flatten(elem)
        else:
            yield elem
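The fan-out and `params` propagation described in the `CallbackList` docstring can be sketched with trimmed-down stand-ins, so the snippet runs without OpenFL. One difference is an assumption about intent rather than a quote of the PR: because `_flatten` is a generator, the sketch materializes its result with `list(...)` so the callbacks can be iterated more than once and appended to. `Recorder` and the `origin` value are hypothetical.

```python
class Callback:
    """Trimmed stand-in for the base class in this diff."""

    def __init__(self):
        self.params = None

    def set_params(self, params):
        self.params = params

    def on_round_end(self, round_num, logs=None):
        pass


def _flatten(items):
    """Flatten a possibly-nested tree of lists (generator, as in the diff)."""
    for elem in items:
        if isinstance(elem, list):
            yield from _flatten(elem)
        else:
            yield elem


class CallbackList(Callback):
    """Stand-in that forwards params and events to every child callback."""

    def __init__(self, callbacks, **params):
        super().__init__()
        # Assumption: wrap in list() so the result survives repeated iteration.
        self.callbacks = list(_flatten(callbacks)) if callbacks else []
        self.set_params(params)

    def set_params(self, params):
        self.params = params
        if params:
            for callback in self.callbacks:
                callback.set_params(params)

    def on_round_end(self, round_num, logs=None):
        for callback in self.callbacks:
            callback.on_round_end(round_num, logs)


class Recorder(Callback):
    """Hypothetical callback that records which rounds it saw, and from whom."""

    def __init__(self, name):
        super().__init__()
        self.name = name
        self.seen = []

    def on_round_end(self, round_num, logs=None):
        self.seen.append((self.params["origin"], round_num))


a, b, c = Recorder("a"), Recorder("b"), Recorder("c")
callbacks = CallbackList([a, [b, [c]]], origin="aggregator")
callbacks.on_round_end(0)
callbacks.on_round_end(1)
print([cb.name for cb in callbacks.callbacks])
print(c.seen)
```

Keyword arguments given at construction become the shared `params` dict, which is how each child learns its `origin` without any per-call `**kwargs`.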
@@ -0,0 +1,38 @@
# Copyright 2020-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from openfl.callbacks.callback import Callback


class LambdaCallback(Callback):
    """Custom on-the-fly callbacks.

    This callback can be constructed with functions that will be called
    at the appropriate time during the life-cycle of a Federated Learning experiment.
    Certain callbacks may expect positional arguments, for example:

    * on_round_begin: expects `round_num` as a positional argument.
    * on_round_end: expects `round_num` as a positional argument.

    Args:
        on_round_begin: called at the beginning of every round.
        on_round_end: called at the end of every round.
        on_experiment_begin: called at the beginning of an experiment.
        on_experiment_end: called at the end of an experiment.
    """

    def __init__(
        self,
        on_round_begin=None,
        on_round_end=None,
        on_experiment_begin=None,
        on_experiment_end=None,
    ):
        super().__init__()
        if on_round_begin is not None:
            self.on_round_begin = on_round_begin
        if on_round_end is not None:
            self.on_round_end = on_round_end
        if on_experiment_begin is not None:
            self.on_experiment_begin = on_experiment_begin
        if on_experiment_end is not None:
            self.on_experiment_end = on_experiment_end
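A short usage sketch for `LambdaCallback` follows. The two classes are local stand-ins that copy the logic from this diff so the example runs without OpenFL; the `events` list and the `acc` metric are hypothetical. Note that the supplied functions are stored as instance attributes, so they are called without `self` and should accept the same arguments as the hook they replace.

```python
class Callback:
    """Trimmed stand-in for the base class: every hook is a no-op."""

    def on_round_begin(self, round_num, logs=None):
        pass

    def on_round_end(self, round_num, logs=None):
        pass

    def on_experiment_begin(self, logs=None):
        pass

    def on_experiment_end(self, logs=None):
        pass


class LambdaCallback(Callback):
    """Stand-in copying the constructor logic from this diff."""

    def __init__(
        self,
        on_round_begin=None,
        on_round_end=None,
        on_experiment_begin=None,
        on_experiment_end=None,
    ):
        super().__init__()
        if on_round_begin is not None:
            self.on_round_begin = on_round_begin
        if on_round_end is not None:
            self.on_round_end = on_round_end
        if on_experiment_begin is not None:
            self.on_experiment_begin = on_experiment_begin
        if on_experiment_end is not None:
            self.on_experiment_end = on_experiment_end


events = []
cb = LambdaCallback(
    on_round_begin=lambda round_num, logs=None: events.append(f"begin {round_num}"),
    on_round_end=lambda round_num, logs=None: events.append(f"end {round_num}: {logs}"),
    on_experiment_end=lambda logs=None: events.append("done"),
)

cb.on_experiment_begin()  # not supplied, so the base-class no-op runs
cb.on_round_begin(0)
cb.on_round_end(0, {"acc": 0.8})
cb.on_experiment_end()
print(events)
```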
@@ -0,0 +1,64 @@
# Copyright 2020-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""Memory Profiler callback."""

import json
import logging
import os

import psutil

from openfl.callbacks.callback import Callback

logger = logging.getLogger(__name__)


class MemoryProfiler(Callback):
    """Profile memory usage of the current process at the end of each round.

    Attributes:
        log_dir: If set, writes logs as lines of JSON.
    """

    def __init__(self, log_dir: str = "./logs/"):
        super().__init__()
        self.log_dir = None
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
            self.log_dir = log_dir

    def on_round_end(self, round_num: int, logs=None):
        origin = self.params["origin"]

        info = _get_memory_usage()
        info["round_number"] = round_num
        info["origin"] = origin

        logger.info(f"Round {round_num}: Memory usage: {info}")
        if self.log_dir:
            with open(os.path.join(self.log_dir, f"{origin}_memory_usage.json"), "a") as f:
                f.write(json.dumps(info) + "\n")


def _get_memory_usage() -> dict:
    process = psutil.Process(os.getpid())
    virtual_memory = psutil.virtual_memory()
    swap_memory = psutil.swap_memory()
    info = {
        "process_memory": round(process.memory_info().rss / (1024**2), 2),
        "virtual_memory/total": round(virtual_memory.total / (1024**2), 2),
        "virtual_memory/available": round(virtual_memory.available / (1024**2), 2),
        "virtual_memory/percent": virtual_memory.percent,
        "virtual_memory/used": round(virtual_memory.used / (1024**2), 2),
        "virtual_memory/free": round(virtual_memory.free / (1024**2), 2),
        "virtual_memory/active": round(virtual_memory.active / (1024**2), 2),
        "virtual_memory/inactive": round(virtual_memory.inactive / (1024**2), 2),
        "virtual_memory/buffers": round(virtual_memory.buffers / (1024**2), 2),
        "virtual_memory/cached": round(virtual_memory.cached / (1024**2), 2),
        "virtual_memory/shared": round(virtual_memory.shared / (1024**2), 2),
        "swap_memory/total": round(swap_memory.total / (1024**2), 2),
        "swap_memory/used": round(swap_memory.used / (1024**2), 2),
        "swap_memory/free": round(swap_memory.free / (1024**2), 2),
        "swap_memory/percent": swap_memory.percent,
    }
    return info
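Since `MemoryProfiler` appends one JSON object per line to `{origin}_memory_usage.json`, the file can be read back with the standard library alone. The sketch below is a hedged illustration: `read_memory_log` is a hypothetical helper, and the two sample records contain made-up values that use a subset of the keys produced by `_get_memory_usage` (sizes are in MiB, since the diff divides byte counts by `1024**2`).

```python
import json
import os
import tempfile


def read_memory_log(path):
    """Parse a JSON-lines file into a list of dicts, skipping blank lines."""
    with open(path) as f:
        return [json.loads(line) for line in f if line.strip()]


# Made-up records for illustration only; real ones carry many more keys.
sample = [
    {"round_number": 0, "origin": "aggregator", "process_memory": 512.25},
    {"round_number": 1, "origin": "aggregator", "process_memory": 530.5},
]

with tempfile.TemporaryDirectory() as log_dir:
    path = os.path.join(log_dir, "aggregator_memory_usage.json")
    # Mirror the write pattern from the diff: append mode, one object per line.
    with open(path, "a") as f:
        for record in sample:
            f.write(json.dumps(record) + "\n")
    records = read_memory_log(path)

growth = records[-1]["process_memory"] - records[0]["process_memory"]
print(f"RSS grew by {growth:.2f} MiB over {len(records)} rounds")
```

Despite the `.json` extension, the file as a whole is not a single JSON document, so `json.load` on the entire file would fail once more than one round has been written.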
Based on my read through, I see that the `logs` named-arg is only used by `MetricWriter`.
For extensibility, would it make more sense to have a `**kwargs` instead?
We would still need to consider how these "dynamic" named args are sent by the aggregator/collaborator but this could be maybe thought of in subsequent/future changes to the `callbacks` API.

Generally speaking, `**kwargs` should be avoided. To pass information that specific callbacks would need, we have `params` that are declared at the time of instantiating `callbacks` list. This makes it clear what attributes `callbacks` will have access to, internally. Dynamic setting makes it harder to inspect and test code.
openfl/openfl/callbacks/callback_list.py
Lines 18 to 20 in 38f7c30
Currently, `logs` are the default because most `callbacks` generally trigger actions based on metrics - say writing `logs` to file, trigger saving a checkpoint if `logs["loss"] < 0.2`, write `logs` to tensorboard etc.
That said, I agree that we can expand this argument list further - if more callbacks benefit from them.
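The pattern the reply describes, static `params` declared up front plus per-round `logs`, can be sketched as follows. Everything here is illustrative: `Callback` is a trimmed local stand-in, `CheckpointOnLowLoss` is a hypothetical callback, and the checkpoint step is replaced by a simple record so the snippet stays self-contained.

```python
class Callback:
    """Trimmed stand-in for the base class in this PR."""

    def __init__(self):
        self.params = None

    def set_params(self, params):
        self.params = params

    def on_round_end(self, round_num, logs=None):
        pass


class CheckpointOnLowLoss(Callback):
    """Hypothetical callback: acts when the reported loss drops below a threshold."""

    def __init__(self, threshold=0.2):
        super().__init__()
        self.threshold = threshold
        self.saved_rounds = []

    def on_round_end(self, round_num, logs=None):
        loss = (logs or {}).get("loss")
        if loss is not None and loss < self.threshold:
            # A real callback would persist model state here; this sketch
            # only records which participant and round triggered the save.
            self.saved_rounds.append((self.params["origin"], round_num))


cb = CheckpointOnLowLoss(threshold=0.2)
cb.set_params({"origin": "collaborator1"})  # static context, set once
for round_num, loss in enumerate([0.6, 0.25, 0.19, 0.1]):
    cb.on_round_end(round_num, logs={"loss": loss})  # dynamic metrics, per round
print(cb.saved_rounds)
```

Keeping the static context in `params` and the changing metrics in `logs` means the hook signatures stay fixed, which is the inspectability argument made in the reply.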