Support Competing Risks #433

Open
wants to merge 44 commits into
base: main
e877d3c
refactor: TaskSurv, handling different censoring types
bblodfon Feb 6, 2025
f525738
refactor: TaskSurv, handling different censoring types
bblodfon Feb 6, 2025
0aaf0fe
Merge branch 'competing_risk' of https://github.com/mlr-org/mlr3proba…
bblodfon Feb 7, 2025
49f7ffa
updocs
bblodfon Feb 7, 2025
c09334b
updocs2
bblodfon Feb 7, 2025
b6c3598
add TaskCompRisks
bblodfon Feb 7, 2025
3fa3b83
make Task fields read-only
bblodfon Feb 7, 2025
e067d87
refer to mlr3 learner properties and remove non-up-to-date doc from m…
bblodfon Feb 7, 2025
0db1769
refine error expectation
bblodfon Feb 7, 2025
d70081b
supress possible warnings
bblodfon Feb 7, 2025
357c50c
update reflections adding competing risks
bblodfon Feb 7, 2025
4bb3a32
fix autotest
bblodfon Feb 7, 2025
b0c3bec
fix test
bblodfon Feb 7, 2025
10dd77a
add CRs section ot the website
bblodfon Feb 7, 2025
4ad3510
make Measure classes cloneable again
bblodfon Feb 18, 2025
c16140a
Merge branch 'main' into competing_risk
bblodfon Feb 21, 2025
44941de
add AJ estimator ref + reformat
bblodfon Feb 21, 2025
67327a5
refactor: all generated tasks in one file + survival::pbc gets time c…
bblodfon Feb 21, 2025
75eff3a
add some more reflections for comp risks tasks and learner
bblodfon Feb 21, 2025
54dce0b
add comp risks base learner class
bblodfon Feb 21, 2025
9e0782f
fix ref
bblodfon Feb 21, 2025
7fe75e1
update learners doc
bblodfon Feb 21, 2025
5e8f137
update measure docs (cloneable = TRUE)
bblodfon Feb 21, 2025
c793219
add Aalen-Johansen estimator for competing risks
bblodfon Feb 21, 2025
0cc9cd0
update files
bblodfon Feb 28, 2025
dbcda03
import setNames
bblodfon Feb 28, 2025
2f2e8aa
add filter() and unique_events() methods for cmprsk tasks
bblodfon Mar 1, 2025
d7160e7
small fix
bblodfon Mar 1, 2025
4d33ac7
add assert_cif_list() function
bblodfon Mar 1, 2025
fedca07
fixing task$truth() from cpmrisk tasks
bblodfon Mar 1, 2025
5e28626
doc the 'n' parameter
bblodfon Mar 1, 2025
6a64b31
rename
bblodfon Mar 1, 2025
764c88b
proper transform of PredictionSurv to data.table
bblodfon Mar 2, 2025
2231a6a
fix test
bblodfon Mar 2, 2025
b8b6602
remove unneeded namespace
bblodfon Mar 2, 2025
0134abd
add PredictionCompRisks and PredictionDataCompRisks classes and S3 me…
bblodfon Mar 6, 2025
657932a
refactor test
bblodfon Mar 6, 2025
1baf8db
some helper test functions + test for CR predictions
bblodfon Mar 6, 2025
990b20e
fix name
bblodfon Mar 6, 2025
8adca9c
fix doc
bblodfon Mar 6, 2025
57801e9
add new entries in pkgdown website
bblodfon Mar 6, 2025
9d0d923
better example of CIF predictions
bblodfon Mar 6, 2025
ae5069a
add type check on list assertion
bblodfon Mar 7, 2025
094fb2a
properly combine CIFs from many prediction objects
bblodfon Mar 7, 2025
12 changes: 9 additions & 3 deletions DESCRIPTION
@@ -73,8 +73,10 @@ NeedsCompilation: no
Roxygen: list(markdown = TRUE, r6 = TRUE)
RoxygenNote: 7.3.2
Collate:
'LearnerDens.R'
'LearnerCompRisks.R'
'aaa.R'
'LearnerCompRisksAalenJohansen.R'
'LearnerDens.R'
'LearnerDensHistogram.R'
'LearnerDensKDE.R'
'LearnerSurv.R'
@@ -120,20 +122,24 @@ Collate:
'PipeOpSurvAvg.R'
'PipeOpTaskSurvClassifDiscTime.R'
'PipeOpTaskSurvClassifIPCW.R'
'PredictionCompRisks.R'
'PredictionDataCompRisks.R'
'PredictionDataDens.R'
'PredictionDataSurv.R'
'PredictionDens.R'
'PredictionSurv.R'
'RcppExports.R'
'TaskCompRisks.R'
'TaskDens.R'
'TaskDens_zzz.R'
'TaskGeneratorCoxed.R'
'TaskGeneratorSimdens.R'
'TaskGeneratorSimsurv.R'
'TaskSurv.R'
'TaskSurv_zzz.R'
'Task_zzz.R'
'as_prediction_cmprsk.R'
'as_prediction_dens.R'
'as_prediction_surv.R'
'as_task_cmprisk.R'
'as_task_dens.R'
'as_task_surv.R'
'assertions.R'
20 changes: 19 additions & 1 deletion NAMESPACE
@@ -1,28 +1,38 @@
# Generated by roxygen2: do not edit by hand

S3method(as.data.table,PredictionCompRisks)
S3method(as.data.table,PredictionDens)
S3method(as.data.table,PredictionSurv)
S3method(as_prediction,PredictionDataCompRisks)
S3method(as_prediction,PredictionDataDens)
S3method(as_prediction,PredictionDataSurv)
S3method(as_prediction_cmprsk,PredictionCompRisks)
S3method(as_prediction_cmprsk,data.frame)
S3method(as_prediction_dens,PredictionDens)
S3method(as_prediction_dens,data.frame)
S3method(as_prediction_surv,PredictionSurv)
S3method(as_prediction_surv,data.frame)
S3method(as_task_cmprsk,DataBackend)
S3method(as_task_cmprsk,TaskCompRisks)
S3method(as_task_cmprsk,data.frame)
S3method(as_task_dens,DataBackend)
S3method(as_task_dens,TaskDens)
S3method(as_task_dens,data.frame)
S3method(as_task_surv,DataBackend)
S3method(as_task_surv,TaskSurv)
S3method(as_task_surv,data.frame)
S3method(as_task_surv,formula)
S3method(autoplot,PredictionSurv)
S3method(autoplot,TaskDens)
S3method(autoplot,TaskSurv)
S3method(c,PredictionDataCompRisks)
S3method(c,PredictionDataDens)
S3method(c,PredictionDataSurv)
S3method(check_prediction_data,PredictionDataCompRisks)
S3method(check_prediction_data,PredictionDataDens)
S3method(check_prediction_data,PredictionDataSurv)
S3method(filter_prediction_data,PredictionDataCompRisks)
S3method(filter_prediction_data,PredictionDataSurv)
S3method(is_missing_prediction_data,PredictionDataCompRisks)
S3method(is_missing_prediction_data,PredictionDataDens)
S3method(is_missing_prediction_data,PredictionDataSurv)
S3method(pecs,PredictionSurv)
@@ -31,6 +41,8 @@ S3method(plot,TaskDens)
S3method(plot,TaskSurv)
export(.c_weight_survival_score)
export(.surv_return)
export(LearnerCompRisks)
export(LearnerCompRisksAalenJohansen)
export(LearnerDens)
export(LearnerDensHistogram)
export(LearnerDensKDE)
@@ -77,17 +89,22 @@ export(PipeOpResponseCompositor)
export(PipeOpSurvAvg)
export(PipeOpTaskSurvClassifDiscTime)
export(PipeOpTaskSurvClassifIPCW)
export(PredictionCompRisks)
export(PredictionDens)
export(PredictionSurv)
export(TaskCompRisks)
export(TaskDens)
export(TaskGeneratorCoxed)
export(TaskGeneratorSimdens)
export(TaskGeneratorSimsurv)
export(TaskSurv)
export(as_prediction_cmprsk)
export(as_prediction_dens)
export(as_prediction_surv)
export(as_task_cmprsk)
export(as_task_dens)
export(as_task_surv)
export(assert_cif_list)
export(assert_surv)
export(assert_surv_matrix)
export(breslow)
@@ -119,6 +136,7 @@ importFrom(stats,model.matrix)
importFrom(stats,predict)
importFrom(stats,quantile)
importFrom(stats,sd)
importFrom(stats,setNames)
importFrom(survival,Surv)
importFrom(utils,data)
importFrom(utils,getFromNamespace)
5 changes: 4 additions & 1 deletion NEWS.md
@@ -1,6 +1,9 @@
# mlr3proba dev

* fix: allow cloning of measures objects
* fix: allow cloning of measure objects
* refactor: `TaskSurv` uses only right, left or interval censoring, simplified code a lot in the methods
* feat: add `TaskCompRisks` class and `as_task_cmprsk()` S3 methods (support for right-censored data only)
* fix: `as.data.table()` for `PredictionSurv` objects now holds a survival curve per observation, as it should

# mlr3proba 0.7.4

47 changes: 47 additions & 0 deletions R/LearnerCompRisks.R
@@ -0,0 +1,47 @@
#' @title Competing Risks Learner
#'
#' @description
#' This Learner specializes [Learner] for competing risks problems:
#'
#' - `task_type` is set to `"cmprsk"`
#' - Creates [Prediction]s of class [PredictionCompRisks].
#' - The only currently available option for `predict_types` is `"cif"`, which
#' represents the predicted **cumulative incidence function** for each observation
#' in the test set.
#'
#' @template param_id
#' @template param_param_set
#' @template param_predict_types
#' @template param_feature_types
#' @template param_learner_properties
#' @template param_packages
#' @template param_label
#' @template param_man
#'
#' @family Learner
#' @export
#' @examples
#' library(mlr3)
#' # get all competing risks learners from mlr_learners:
#' lrns = mlr_learners$mget(mlr_learners$keys("^cmprsk"))
#' names(lrns)
#'
#' # get a specific learner from mlr_learners:
#' mlr_learners$get("cmprsk.aalen")
#' lrn("cmprsk.aalen")
LearnerCompRisks = R6Class("LearnerCompRisks",
inherit = Learner,
public = list(
#' @description Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(id, param_set = ps(), predict_types = "cif",
feature_types = character(), properties = character(),
packages = character(), label = NA_character_, man = NA_character_) {

super$initialize(
id = id, task_type = "cmprsk", param_set = param_set, predict_types = predict_types,
feature_types = feature_types, properties = properties,
packages = c("mlr3proba", packages), label = label, man = man
)
}
)
)
78 changes: 78 additions & 0 deletions R/LearnerCompRisksAalenJohansen.R
@@ -0,0 +1,78 @@
#' @title Aalen Johansen Competing Risks Learner
#' @templateVar fullname LearnerCompRisksAalenJohansen
#' @templateVar id cmprsk.aalen
#' @template cmprsk_learner
#'
#' @description
#'
#' This learner estimates the Cumulative Incidence Function (CIF) for competing
#' risks using the empirical Aalen-Johansen (AJ) estimator.
#' The probability of transitioning to each competing event is computed via the
#' [survfit][survival::survfit.formula()] function.
#'
#' @references
#' `r format_bib("aalen_1978")`
#'
#' @export
LearnerCompRisksAalenJohansen = R6Class("LearnerCompRisksAalenJohansen",
inherit = LearnerCompRisks,
public = list(
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function() {
param_set = ps(
model = p_lgl(default = FALSE, tags = "train")
)

super$initialize(
id = "cmprsk.aalen",
param_set = param_set,
predict_types = "cif",
feature_types = c("logical", "integer", "numeric", "factor"),
properties = "weights",
packages = "survival",
label = "Aalen Johansen Estimator",
man = "mlr3proba::mlr_learners_cmprsk.aalen"
)
}
),

private = list(
.train = function(task) {
pv = self$param_set$get_values(tags = "train")

if ("weights" %in% task$properties) {
pv$weights = task$weights$weight
}

invoke(survival::survfit,
formula = task$formula(1),
data = task$data(cols = task$target_names),
.args = pv
)
},

.predict = function(task) {
trans_mat = self$model$pstate
trans_mat = trans_mat[, -1] # drop "(s0)": probability of remaining in the initial (event-free) state

times = self$model$time # unique train set time points
n_obs = task$nrow # number of test observations
CIF = stats::setNames(vector("list", ncol(trans_mat)), colnames(trans_mat))

for (i in seq_along(CIF)) {
CIF[[i]] = matrix(
data = rep(trans_mat[, i], times = n_obs),
nrow = n_obs,
byrow = TRUE,
dimnames = list(NULL, times)
)
}

list(cif = CIF)
}
)
)

#' @include aaa.R
register_learner("cmprsk.aalen", LearnerCompRisksAalenJohansen)
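
The `.train()` and `.predict()` steps above can be sketched outside of mlr3 using only the `survival` package. This is a rough illustration under stated assumptions, not the PR's code: the toy data, the event labels `cause1`/`cause2` and `n_obs` are made up, and the column names are taken from `fit$states` explicitly rather than relying on `pstate` carrying them.

```r
library(survival)

# toy right-censored competing risks data (hypothetical values)
d = data.frame(
  time = c(1, 2, 3, 4, 5, 6, 7, 8),
  status = factor(c(0, 1, 2, 1, 0, 2, 1, 2), levels = 0:2,
    labels = c("censored", "cause1", "cause2"))
)

# intercept-only Aalen-Johansen fit, analogous to `task$formula(1)`;
# the first factor level is treated as censoring
fit = survfit(Surv(time, status) ~ 1, data = d)

# first column is the initial state "(s0)", the rest are per-cause CIFs
trans_mat = fit$pstate[, -1, drop = FALSE]
colnames(trans_mat) = fit$states[-1]

# no covariates are used, so every test observation gets the same CIF
n_obs = 3 # hypothetical number of test observations
cif = lapply(stats::setNames(nm = colnames(trans_mat)), function(ev) {
  matrix(rep(trans_mat[, ev], times = n_obs), nrow = n_obs, byrow = TRUE,
    dimnames = list(NULL, fit$time))
})
```

Each element of `cif` is an `n_obs` x `length(fit$time)` matrix with one row per observation and one column per unique training time point. At every time point the per-cause CIFs and the `(s0)` column sum to 1.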
4 changes: 3 additions & 1 deletion R/LearnerSurv.R
@@ -39,10 +39,12 @@ LearnerSurv = R6Class("LearnerSurv",
initialize = function(id, param_set = ps(), predict_types = "distr",
feature_types = character(), properties = character(),
packages = character(), label = NA_character_, man = NA_character_) {

super$initialize(
id = id, task_type = "surv", param_set = param_set, predict_types = predict_types,
feature_types = feature_types, properties = properties,
packages = c("mlr3proba", packages), label = label, man = man)
packages = c("mlr3proba", packages), label = label, man = man
)
}
)
)
2 changes: 1 addition & 1 deletion R/MeasureSurvGraf.R
@@ -49,7 +49,7 @@
#' @template details_tmax
#'
#' @references
#' `r format_bib("graf_1999", "sonabend2024", "kvamme2023")`
#' `r format_bib("graf_1999", "sonabend_2024", "kvamme_2023")`
#'
#' @family Probabilistic survival measures
#' @family distr survival measures
2 changes: 1 addition & 1 deletion R/MeasureSurvICI.R
@@ -49,7 +49,7 @@
#' cases.
#'
#' @references
#' `r format_bib("austin2020")`
#' `r format_bib("austin_2020")`
#'
#' @family calibration survival measures
#' @family distr survival measures
2 changes: 1 addition & 1 deletion R/MeasureSurvIntLogloss.R
@@ -47,7 +47,7 @@
#' @template details_tmax
#'
#' @references
#' `r format_bib("graf_1999", "sonabend2024", "kvamme2023")`
#' `r format_bib("graf_1999", "sonabend_2024", "kvamme_2023")`
#'
#' @family Probabilistic survival measures
#' @family distr survival measures
2 changes: 1 addition & 1 deletion R/MeasureSurvLogloss.R
@@ -34,7 +34,7 @@
#' @template details_trainG
#'
#' @references
#' `r format_bib("sonabend2024")`
#' `r format_bib("sonabend_2024")`
#'
#' @family Probabilistic survival measures
#' @family distr survival measures
2 changes: 1 addition & 1 deletion R/MeasureSurvSchmid.R
@@ -46,7 +46,7 @@
#' @template details_tmax
#'
#' @references
#' `r format_bib("schemper_2000", "schmid_2011", "sonabend2024", "kvamme2023")`
#' `r format_bib("schemper_2000", "schmid_2011", "sonabend_2024", "kvamme_2023")`
#'
#' @family Probabilistic survival measures
#' @family distr survival measures
2 changes: 1 addition & 1 deletion R/PipeOpTaskSurvClassifDiscTime.R
@@ -116,7 +116,7 @@ PipeOpTaskSurvClassifDiscTime = R6Class("PipeOpTaskSurvClassifDiscTime",
private = list(
.train = function(input) {
task = input[[1L]]
assert_true(task$censtype == "right")
assert_true(task$cens_type == "right")
data = task$data()

if ("disc_status" %in% colnames(task$data())) {
2 changes: 1 addition & 1 deletion R/PipeOpTaskSurvClassifIPCW.R
@@ -135,7 +135,7 @@ PipeOpTaskSurvClassifIPCW = R6Class("PipeOpTaskSurvClassifIPCW",
task = input[[1]]

# checks
assert_true(task$censtype == "right")
assert_true(task$cens_type == "right")
tau = assert_numeric(self$param_set$values$tau, null.ok = FALSE)
max_event_time = max(task$unique_event_times())
stopifnot(tau <= max_event_time)