From 0859da6da0b98818180a6c9e5580018744d7079f Mon Sep 17 00:00:00 2001 From: Honghao LI Date: Mon, 22 Jan 2024 12:59:04 +0000 Subject: [PATCH] replace print with `logging` messages --- fedeca/fedeca_core.py | 47 +++++++++++++++++---------------- fedeca/utils/substrafl_utils.py | 19 ++++++++----- 2 files changed, 36 insertions(+), 30 deletions(-) diff --git a/fedeca/fedeca_core.py b/fedeca/fedeca_core.py index 3e731c6a..9e3dd3cd 100644 --- a/fedeca/fedeca_core.py +++ b/fedeca/fedeca_core.py @@ -1,4 +1,5 @@ """Federate causal inference on distributed data.""" +import logging import sys import time from collections.abc import Callable @@ -34,6 +35,8 @@ from fedeca.utils.substrafl_utils import get_outmodel_function from fedeca.utils.survival_utils import BaseSurvivalEstimator, CoxPHModelTorch +logger = logging.getLogger(__name__) + class FedECA(Experiment, BaseSurvivalEstimator): """FedECA class tthat performs Federated IPTW.""" @@ -308,7 +311,7 @@ def check_cp_status(self, idx=0): model_name = "Robust Variance" training_type = "estimation" - print(f"Waiting on {model_name} {training_type} to finish...") + logger.info(f"Waiting on {model_name} {training_type} to finish...") t1 = time.time() t2 = t1 while (t2 - t1) < self.timeout: @@ -316,7 +319,7 @@ def check_cp_status(self, idx=0): self.compute_plan_keys[idx].key ).status if status == ComputePlanStatus.done: - print( + logger.info( f"""Compute plan {self.compute_plan_keys[0].key} of {model_name} has finished !""" ) @@ -336,7 +339,7 @@ def check_cp_status(self, idx=0): ): pass else: - print( + logger.warning( f"""Compute plan status is {status}, this shouldn't happen, sleeping {self.time_sleep} and retrying until timeout {self.timeout}""" ) @@ -518,7 +521,7 @@ def fit( if backend_type != "remote" and ( urls is not None or server_org_id is not None or tokens is not None ): - print( + logger.warning( "urls, server_org_id and tokens are ignored if backend_type is " "not remote; Make sure that you launched the fit with the right" " combination of parameters." @@ -598,9 +601,7 @@ def __init__(self): ) # We put WebDisco in "robust" mode in the sense that we ask it # to store all needed quantities for robust variance estimation - self.strategies[ - 1 - ].algo._robust = True # not sufficient for serialization + self.strategies[1].algo._robust = True # not sufficient for serialization # possible only because we added robust as a kwargs self.strategies[1].algo.kwargs.update({"robust": True}) # We need those two lines for the zip to consider all 3 @@ -616,9 +617,9 @@ def __init__(self): def run(self, targets: Union[pd.DataFrame, None] = None): """Run the federated iptw algorithms.""" del targets - print("Careful for now the argument target is ignored completely") + logger.info("Careful for now the argument target is ignored completely") # We first run the propensity model - print("Fitting the propensity model...") + logger.info("Fitting the propensity model...") t1 = time.time() super().run(1) @@ -629,11 +630,11 @@ def run(self, targets: Union[pd.DataFrame, None] = None): ) else: self.performances_propensity_model = self.performances_strategies[0] - print(self.performances_propensity_model) + logger.info(self.performances_propensity_model) t2 = time.time() self.propensity_model_fit_time = t2 - t1 - print(f"Time to fit Propensity model {self.propensity_model_fit_time}s") - print("Finished, recovering the final propensity model from substra") + logger.info(f"Time to fit Propensity model {self.propensity_model_fit_time}s") + logger.info("Finished, recovering the final propensity model from substra") # TODO to add the opportunity to use the targets you have to either: # give the full targets to every client as a kwargs of their Algo # so effectively one would need to reinstantiate algos objects or to @@ -665,7 +666,7 @@ def run(self, targets: Union[pd.DataFrame, None] = None): for t in self.train_data_nodes: t.keep_intermediate_states = True - print("Fitting propensity weighted Cox model...") + logger.info("Fitting propensity weighted Cox model...") t1 = time.time() super().run(1) @@ -673,8 +674,8 @@ def run(self, targets: Union[pd.DataFrame, None] = None): self.check_cp_status(idx=1) t2 = time.time() self.webdisco_fit_time = t2 - t1 - print(f"Time to fit WebDisco {self.webdisco_fit_time}s") - print("Finished fitting weighted Cox model.") + logger.info(f"Time to fit WebDisco {self.webdisco_fit_time}s") + logger.info("Finished fitting weighted Cox model.") self.total_fit_time = self.propensity_model_fit_time + self.webdisco_fit_time self.print_summary() @@ -683,19 +684,19 @@ def print_summary(self): assert ( len(self.compute_plan_keys) == 2 ), "You need to run the run method before getting the summary" - print("Evolution of performance of propensity model:") - print(self.performances_propensity_model) - print("Checking if the Cox model has converged:") + logger.info("Evolution of performance of propensity model:") + logger.info(self.performances_propensity_model) + logger.info("Checking if the Cox model has converged:") self.get_final_cox_model() - print("Computing summary...") + logger.info("Computing summary...") self.compute_summary() - print("Final partial log-likelihood:") - print(self.ll) - print(self.results_) + logger.info("Final partial log-likelihood:") + logger.info(self.ll) + logger.info(self.results_) def get_final_cox_model(self): """Retrieve final cox model.""" - print("Retrieving final hessian and log-likelihood") + logger.info("Retrieving final hessian and log-likelihood") if not self.simu_mode: cp = self.compute_plan_keys[1].key else: diff --git a/fedeca/utils/substrafl_utils.py b/fedeca/utils/substrafl_utils.py index 5637512d..2c72b2e6 100644 --- a/fedeca/utils/substrafl_utils.py +++ b/fedeca/utils/substrafl_utils.py @@ -1,4 +1,5 @@ """Utils functions for Substra.""" +import logging import os import pickle import tempfile @@ -35,6 +36,8 @@ _check_environment_compatibility, ) +logger = logging.getLogger(__name__) + class Experiment: """Experiment class.""" @@ -108,11 +111,11 @@ def __init__( if metrics_dicts_list and not all( [len(t.metric_functions) == 0 for t in self.test_data_nodes] ): - print( + logger.warning( """WARNING: you are passing metrics to test data nodes with existing metric_functions this will overwrite them""" ) - print( + logger.warning( [ (f"Client {i}", t.metric_functions) for i, t in enumerate(self.test_data_nodes) @@ -253,7 +256,7 @@ def run(self, num_strategies_to_run=None): # If no AggregationNode is given we take the first one if self.aggregation_node is None: - print("Using the first client as a server.") + logger.info("Using the first client as a server.") kwargs_agg_node = { "organization_id": self.train_data_nodes[0].organization_id } @@ -315,12 +318,12 @@ def run(self, num_strategies_to_run=None): scores = [t.scores for t in self.test_data_nodes] robust_cox_variance = False for idx, s in enumerate(scores): - print(f"====Client {idx}====") + logger.info(f"====Client {idx}====") try: - print(s[-1]) + logger.info(s[-1]) except IndexError: robust_cox_variance = True - print("No metric") + logger.info("No metric") # TODO Check that it is well formatted it's probably not self.performances_strategies.append(pd.DataFrame(xp_output)) # Hacky hacky hack @@ -515,7 +518,9 @@ def make_substrafl_torch_dataset_class( [t in [event_col, duration_col] for t in target_cols] ) if len(target_cols) == 1: - print(f"Making a dataset class to fit a model to predict {target_cols[0]}") + logger.info( + f"Making a dataset class to fit a model to predict {target_cols[0]}" + ) columns_to_drop = [event_col, duration_col] elif len(target_cols) == 2: assert set(target_cols) == set(