diff --git a/fedeca/algorithms/torch_webdisco_algo.py b/fedeca/algorithms/torch_webdisco_algo.py index 36989cd8..5d34f2b1 100644 --- a/fedeca/algorithms/torch_webdisco_algo.py +++ b/fedeca/algorithms/torch_webdisco_algo.py @@ -1,7 +1,6 @@ """Implement webdisco algorithm with Torch.""" import copy from copy import deepcopy -from math import sqrt from pathlib import Path from typing import Any, List, Optional, Union @@ -11,7 +10,6 @@ from autograd import elementwise_grad from autograd import numpy as anp from lifelines.utils import StepSizer -from pandas.api.types import is_numeric_dtype from scipy.linalg import norm from scipy.linalg import solve as spsolve from substrafl.algorithms.pytorch import weight_manager diff --git a/fedeca/utils/survival_utils.py b/fedeca/utils/survival_utils.py index 20069cab..737ed471 100644 --- a/fedeca/utils/survival_utils.py +++ b/fedeca/utils/survival_utils.py @@ -1297,6 +1297,9 @@ def robust_sandwich_variance_pooled( model. The sandwich variance estimator is a robust estimator of the variance which accounts for the lack of dependence between the samples due to the introduction of weights for example. + + Parameters + ---------- X_norm : np.ndarray or torch.Tensor Input feature matrix of shape (n_samples, n_features). y : np.ndarray or torch.Tensor @@ -1309,6 +1312,11 @@ def robust_sandwich_variance_pooled( Weights associated with each sample, with shape (n_samples,) scaled_variance_matrix : np.ndarray or torch.Tensor Classical scaled variance of the Cox model estimator. + + Returns + ------- + np.ndarray + The robust sandwich variance estimator. """ n_samples, n_features = X_norm.shape @@ -1357,8 +1365,7 @@ def robust_sandwich_variance_pooled( def km_curve(t, n, d, tmax=5000): - """Computes Kaplan-Meier (KM) curve based on unique event times, number of - individuals at risk and number of deaths. + """Compute Kaplan-Meier (KM) curve. This function is typically used in conjunction with `compute_events_statistics`. 
Note that the variance is computed @@ -1482,8 +1489,7 @@ def compute_events_statistics(times, events): def aggregate_events_statistics(list_t_n_d): - """Aggregates (sums) events statistics from different centers, returning a single - tuple with the same format. + """Aggregate (sum) events statistics from different centers. Parameters ---------- @@ -1514,8 +1520,7 @@ def aggregate_events_statistics(list_t_n_d): def extend_events_to_common_grid(list_t_n_d, t_common): - """Extends a list of heterogeneous times, number of people at risk and number of - death on a common grid. + """Extend heterogeneous times, at-risk and death counts to a common grid. This method is an internal utility for `aggregate_events_statistics`.