Skip to content

Commit

Permalink
replace print with logging messages
Browse files Browse the repository at this point in the history
  • Loading branch information
honghaoli42 committed Jan 22, 2024
1 parent f4dba29 commit 0859da6
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 30 deletions.
47 changes: 24 additions & 23 deletions fedeca/fedeca_core.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Federate causal inference on distributed data."""
import logging
import sys
import time
from collections.abc import Callable
Expand Down Expand Up @@ -34,6 +35,8 @@
from fedeca.utils.substrafl_utils import get_outmodel_function
from fedeca.utils.survival_utils import BaseSurvivalEstimator, CoxPHModelTorch

logger = logging.getLogger(__name__)


class FedECA(Experiment, BaseSurvivalEstimator):
"""FedECA class that performs Federated IPTW."""
Expand Down Expand Up @@ -308,15 +311,15 @@ def check_cp_status(self, idx=0):
model_name = "Robust Variance"
training_type = "estimation"

print(f"Waiting on {model_name} {training_type} to finish...")
logger.info(f"Waiting on {model_name} {training_type} to finish...")
t1 = time.time()
t2 = t1
while (t2 - t1) < self.timeout:
status = self.ds_client.get_compute_plan(
self.compute_plan_keys[idx].key
).status
if status == ComputePlanStatus.done:
print(
logger.info(
f"""Compute plan {self.compute_plan_keys[0].key} of {model_name} has
finished !"""
)
Expand All @@ -336,7 +339,7 @@ def check_cp_status(self, idx=0):
):
pass
else:
print(
logger.warning(
f"""Compute plan status is {status}, this shouldn't happen, sleeping
{self.time_sleep} and retrying until timeout {self.timeout}"""
)
Expand Down Expand Up @@ -518,7 +521,7 @@ def fit(
if backend_type != "remote" and (
urls is not None or server_org_id is not None or tokens is not None
):
print(
logger.warning(
"urls, server_org_id and tokens are ignored if backend_type is "
"not remote; Make sure that you launched the fit with the right"
" combination of parameters."
Expand Down Expand Up @@ -598,9 +601,7 @@ def __init__(self):
)
# We put WebDisco in "robust" mode in the sense that we ask it
# to store all needed quantities for robust variance estimation
self.strategies[
1
].algo._robust = True # not sufficient for serialization
self.strategies[1].algo._robust = True # not sufficient for serialization
# possible only because we added robust as a kwarg
self.strategies[1].algo.kwargs.update({"robust": True})
# We need those two lines for the zip to consider all 3
Expand All @@ -616,9 +617,9 @@ def __init__(self):
def run(self, targets: Union[pd.DataFrame, None] = None):
"""Run the federated iptw algorithms."""
del targets
print("Careful for now the argument target is ignored completely")
logger.info("Careful for now the argument target is ignored completely")
# We first run the propensity model
print("Fitting the propensity model...")
logger.info("Fitting the propensity model...")
t1 = time.time()
super().run(1)

Expand All @@ -629,11 +630,11 @@ def run(self, targets: Union[pd.DataFrame, None] = None):
)
else:
self.performances_propensity_model = self.performances_strategies[0]
print(self.performances_propensity_model)
logger.info(self.performances_propensity_model)
t2 = time.time()
self.propensity_model_fit_time = t2 - t1
print(f"Time to fit Propensity model {self.propensity_model_fit_time}s")
print("Finished, recovering the final propensity model from substra")
logger.info(f"Time to fit Propensity model {self.propensity_model_fit_time}s")
logger.info("Finished, recovering the final propensity model from substra")
# TODO to add the opportunity to use the targets you have to either:
# give the full targets to every client as a kwargs of their Algo
# so effectively one would need to reinstantiate algos objects or to
Expand Down Expand Up @@ -665,16 +666,16 @@ def run(self, targets: Union[pd.DataFrame, None] = None):
for t in self.train_data_nodes:
t.keep_intermediate_states = True

print("Fitting propensity weighted Cox model...")
logger.info("Fitting propensity weighted Cox model...")
t1 = time.time()
super().run(1)

if not self.simu_mode:
self.check_cp_status(idx=1)
t2 = time.time()
self.webdisco_fit_time = t2 - t1
print(f"Time to fit WebDisco {self.webdisco_fit_time}s")
print("Finished fitting weighted Cox model.")
logger.info(f"Time to fit WebDisco {self.webdisco_fit_time}s")
logger.info("Finished fitting weighted Cox model.")
self.total_fit_time = self.propensity_model_fit_time + self.webdisco_fit_time
self.print_summary()

Expand All @@ -683,19 +684,19 @@ def print_summary(self):
assert (
len(self.compute_plan_keys) == 2
), "You need to run the run method before getting the summary"
print("Evolution of performance of propensity model:")
print(self.performances_propensity_model)
print("Checking if the Cox model has converged:")
logger.info("Evolution of performance of propensity model:")
logger.info(self.performances_propensity_model)
logger.info("Checking if the Cox model has converged:")
self.get_final_cox_model()
print("Computing summary...")
logger.info("Computing summary...")
self.compute_summary()
print("Final partial log-likelihood:")
print(self.ll)
print(self.results_)
logger.info("Final partial log-likelihood:")
logger.info(self.ll)
logger.info(self.results_)

def get_final_cox_model(self):
"""Retrieve final cox model."""
print("Retrieving final hessian and log-likelihood")
logger.info("Retrieving final hessian and log-likelihood")
if not self.simu_mode:
cp = self.compute_plan_keys[1].key
else:
Expand Down
19 changes: 12 additions & 7 deletions fedeca/utils/substrafl_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Utils functions for Substra."""
import logging
import os
import pickle
import tempfile
Expand Down Expand Up @@ -35,6 +36,8 @@
_check_environment_compatibility,
)

logger = logging.getLogger(__name__)


class Experiment:
"""Experiment class."""
Expand Down Expand Up @@ -108,11 +111,11 @@ def __init__(
if metrics_dicts_list and not all(
[len(t.metric_functions) == 0 for t in self.test_data_nodes]
):
print(
logger.warning(
"""WARNING: you are passing metrics to test data nodes with existing
metric_functions this will overwrite them"""
)
print(
logger.warning(
[
(f"Client {i}", t.metric_functions)
for i, t in enumerate(self.test_data_nodes)
Expand Down Expand Up @@ -253,7 +256,7 @@ def run(self, num_strategies_to_run=None):

# If no AggregationNode is given we take the first one
if self.aggregation_node is None:
print("Using the first client as a server.")
logger.info("Using the first client as a server.")
kwargs_agg_node = {
"organization_id": self.train_data_nodes[0].organization_id
}
Expand Down Expand Up @@ -315,12 +318,12 @@ def run(self, num_strategies_to_run=None):
scores = [t.scores for t in self.test_data_nodes]
robust_cox_variance = False
for idx, s in enumerate(scores):
print(f"====Client {idx}====")
logger.info(f"====Client {idx}====")
try:
print(s[-1])
logger.info(s[-1])
except IndexError:
robust_cox_variance = True
print("No metric")
logger.info("No metric")
# TODO: check that it is well formatted; it probably is not
self.performances_strategies.append(pd.DataFrame(xp_output))
# Hacky hacky hack
Expand Down Expand Up @@ -515,7 +518,9 @@ def make_substrafl_torch_dataset_class(
[t in [event_col, duration_col] for t in target_cols]
)
if len(target_cols) == 1:
print(f"Making a dataset class to fit a model to predict {target_cols[0]}")
logger.info(
f"Making a dataset class to fit a model to predict {target_cols[0]}"
)
columns_to_drop = [event_col, duration_col]
elif len(target_cols) == 2:
assert set(target_cols) == set(
Expand Down

0 comments on commit 0859da6

Please sign in to comment.