Skip to content

Commit

Permalink
Merge pull request #3 from payalcha/remove-log-metric
Browse files Browse the repository at this point in the history
Remove log_metric_callback and write_metric from aggregator
  • Loading branch information
payalcha authored Dec 10, 2024
2 parents 2a79c4d + b815f84 commit 46bc921
Showing 1 changed file with 1 addition and 30 deletions.
31 changes: 1 addition & 30 deletions openfl/component/aggregator/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from openfl.pipelines import NoCompressionPipeline, TensorCodec
from openfl.protocols import base_pb2, utils
from openfl.utilities import TaskResultKey, TensorKey, change_tags
from openfl.utilities.logs import get_memory_usage, write_memory_usage_to_file, write_metric
from openfl.utilities.logs import get_memory_usage, write_memory_usage_to_file


class Aggregator:
Expand All @@ -39,7 +39,6 @@ class Aggregator:
db_store_rounds* (int): Rounds to store in TensorDB.
logger: Object for logging.
write_logs (bool): Flag to enable log writing.
log_metric_callback: Callback for logging metrics.
best_model_score (optional): Score of the best model. Defaults to
None.
metric_queue (queue.Queue): Queue for metrics.
Expand Down Expand Up @@ -78,7 +77,6 @@ def __init__(
db_store_rounds=1,
write_logs=False,
log_memory_usage=False,
log_metric_callback=None,
initial_tensor_dict=None,
):
"""Initializes the Aggregator.
Expand Down Expand Up @@ -106,8 +104,6 @@ def __init__(
Defaults to 1.
write_logs (bool, optional): Whether to write logs. Defaults to
False.
log_metric_callback (optional): Callback for log metric. Defaults
to None.
**kwargs: Additional keyword arguments.
"""
self.logger = getLogger(__name__)
Expand Down Expand Up @@ -147,13 +143,6 @@ def __init__(

# Gathered together logging-related objects
self.write_logs = write_logs
self.log_metric_callback = log_metric_callback

if self.write_logs:
self.log_metric = write_metric
if self.log_metric_callback:
self.log_metric = log_metric_callback
self.logger.info("Using custom log metric: %s", self.log_metric)

self.best_model_score = None
self.metric_queue = queue.Queue()
Expand Down Expand Up @@ -664,14 +653,6 @@ def send_local_task_results(
}
self.metric_queue.put(metrics)
self.logger.metric("%s", str(metrics))
if self.write_logs:
self.log_metric(
collaborator_name,
task_name,
tensor_key.tensor_name,
float(value),
round_number,
)

task_results.append(tensor_key)

Expand Down Expand Up @@ -717,9 +698,7 @@ def _process_named_tensor(self, named_tensor, collaborator_name):
Returns:
tensor_key (TensorKey): Tensor key.
The tensorkey extracted from the protobuf.
nparray (np.array): Numpy array.
The numpy array associated with the returned tensorkey.
"""
raw_bytes = named_tensor.data_bytes
metadata = [
Expand Down Expand Up @@ -946,14 +925,6 @@ def _compute_validation_related_task_metrics(self, task_name):

self.metric_queue.put(metrics)
self.logger.metric("%s", metrics)
if self.write_logs:
self.log_metric(
"aggregator",
task_name,
tensor_key.tensor_name,
float(agg_results),
round_number,
)

# FIXME: Configurable logic for min/max criteria in saving best.
if "validate_agg" in tags:
Expand Down

0 comments on commit 46bc921

Please sign in to comment.