Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add surv.finegray _coxph #417

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ Authors@R: c(
comment = c(ORCID = "0000-0002-3609-8674")),
person("Lukas", "Burk", , "[email protected]", role = "ctb",
comment = c(ORCID = "0000-0001-7528-3795")),
person("Lona", "Koers", , "[email protected]", role = "ctb")
person("Lona", "Koers", , "[email protected]", role = "ctb"),
person("Andrzej", "Galecki", , "[email protected]", role = "ctb")
)
Description: Extra learners for use in mlr3.
License: LGPL-3
Expand Down
13 changes: 13 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ export(LearnerSurvCoxtime)
export(LearnerSurvDNNSurv)
export(LearnerSurvDeephit)
export(LearnerSurvDeepsurv)
export(LearnerSurvFineGrayCoxPH)
export(LearnerSurvFlexible)
export(LearnerSurvGAMBoost)
export(LearnerSurvGBM)
Expand Down Expand Up @@ -143,11 +144,23 @@ importFrom(mlr3,LearnerRegr)
importFrom(mlr3,lrn)
importFrom(mlr3,lrns)
importFrom(mlr3,mlr_learners)
importFrom(mlr3proba,LearnerSurv)
importFrom(mlr3proba,TaskSurv)
importFrom(paradox,p_dbl)
importFrom(paradox,p_fct)
importFrom(paradox,p_int)
importFrom(paradox,p_lgl)
importFrom(paradox,p_uty)
importFrom(paradox,ps)
importFrom(stats,as.formula)
importFrom(stats,formula)
importFrom(stats,na.omit)
importFrom(stats,predict)
importFrom(stats,setNames)
importFrom(survival,basehaz)
importFrom(survival,coxph)
importFrom(survival,coxph.control)
importFrom(survival,finegray)
importFrom(utils,capture.output)
importFrom(utils,getFromNamespace)
importFrom(utils,packageVersion)
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# dev

# mlr3extralearners (unreleased)

* Added new survival learner `surv.finegray_coxph` for Fine-Gray competing risks model with Cox PH.

# mlr3extralearners 1.0.0

* Add "Prediction types" doc section for all 30 survival learners + make sure it is consistent #347
Expand Down
148 changes: 148 additions & 0 deletions R/learner_survival_surv_finegray_coxph.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
#' @title Fine-Gray Competing Risks Model with Cox Proportional Hazards
#' @description
#' A learner for fitting a Fine-Gray competing risks model using Cox proportional hazards.
#' Training expands the task data with [survival::finegray()] for one target event type and
#' fits [survival::coxph()] on the expanded `(fgstart, fgstop, fgstatus)` data, weighted by
#' the `fgwt` column that `finegray()` produces. The result estimates the subdistribution
#' hazard of the target event in the presence of competing events.
#' @section Usage:
#' ```
#' learner <- LearnerSurvFineGrayCoxPH$new()
#' ```
#' @section Parameters:
#' - `ties`: Character, method for handling ties ("efron" (default), "breslow", "exact").
#' - `iter.max`: Integer, max iterations for Cox fit (default: 20, range: 1-1000).
#' - `eps`: Numeric, convergence threshold (default: 1e-9, range: 1e-12 to 1e-4).
#' - `robust`: Logical, compute robust variance (default: FALSE).
#' - `target_event`: Event type to model, given either as a level name of the event factor
#'   or as a numeric index into its levels. `NULL` (default) selects the second level.
#'   The first level is the censoring level and cannot be selected.
#' - `singular.ok`: Logical, allow singular predictors (default: TRUE).
#' @section Predict Types:
#' - `crank`: Continuous ranking (linear predictor).
#' - `lp`: Linear predictor as returned by `predict(type = "lp")` on the Cox fit.
#' - `distr`: Matrix (observations x time points, times as colnames) of
#'   `exp(-H(t) * exp(lp))`, where `H` is the baseline cumulative subdistribution hazard,
#'   i.e. the probability of not yet having experienced the target event.
#' @section Properties:
#' - Supports observation weights via the task's weights role (set via `weights` argument or `$set_col_roles()`).
#'   The weights are handed to `finegray()` as case weights.
#' @return For `predict()`, a list with `crank` (numeric), `lp` (numeric), and `distr`
#' (matrix of survival probabilities with times as colnames).
#' @section Methods:
#' - `new()`: Alias for `initialize()`, creates a new instance of the learner with default parameters.
#' - `train(task)`: Train the model on a survival task.
#' - `predict(task)`: Predict on new data from a trained model.
#' @importFrom R6 R6Class
#' @importFrom mlr3proba LearnerSurv TaskSurv
#' @importFrom paradox ps p_fct p_int p_dbl p_lgl p_uty
#' @importFrom survival finegray coxph basehaz coxph.control
#' @export
LearnerSurvFineGrayCoxPH <- R6::R6Class("LearnerSurvFineGrayCoxPH",
  inherit = mlr3proba::LearnerSurv,
  public = list(
    #' @description
    #' Initialize a new Fine-Gray Cox PH learner with its default hyperparameter settings.
    #' No arguments are required; hyperparameters are defined and set via the `paradox` parameter set.
    initialize = function() {
      param_set <- paradox::ps(
        ties = p_fct(default = "efron", levels = c("efron", "breslow", "exact"), tags = "train"),
        iter.max = p_int(default = 20L, lower = 1L, upper = 1000L, tags = "train"),
        eps = p_dbl(default = 1e-9, lower = 1e-12, upper = 1e-4, tags = "train"),
        robust = p_lgl(default = FALSE, tags = "train"),
        target_event = p_uty(default = NULL, tags = "train",
          custom_check = function(x) {
            if (is.null(x)) return(TRUE)
            if (is.character(x) || is.numeric(x)) return(TRUE)
            "target_event must be NULL, a character, or a numeric index"
          }),
        singular.ok = p_lgl(default = TRUE, tags = "train")
      )
      param_set$values <- list(
        ties = "efron",
        iter.max = 20L,
        eps = 1e-9,
        robust = FALSE,
        target_event = NULL,
        singular.ok = TRUE
      )
      super$initialize(
        id = "surv.finegray_coxph",
        param_set = param_set,
        feature_types = c("logical", "integer", "numeric", "factor"),
        predict_types = c("crank", "lp", "distr"),
        properties = "weights",
        packages = "survival",
        label = "Fine-Gray Competing Risks Model with CoxPH",
        # Help topic of this class inside this package (topic name follows the
        # roxygen block above, which documents the object under its own name).
        man = "mlr3extralearners::LearnerSurvFineGrayCoxPH"
      )
    }
  ),
  private = list(
    # Fit the Fine-Gray model.
    #
    # @param task Survival task whose event column is a factor with at least
    #   three levels (censoring level first, then the event types).
    # @return The fitted `coxph` object; it becomes `learner$model`.
    .train = function(task) {
      pv <- self$param_set$get_values(tags = "train")
      # Hyperparameter ranges/types are already validated by the ParamSet; here
      # we only fall back to the documented default when a value was unset.
      or_default <- function(value, default) if (is.null(value)) default else value

      features <- task$feature_names
      if (length(features) == 0) stop("No features provided!")
      time_col <- task$target_names[1]
      event_col <- task$target_names[2]

      event_levels <- task$levels()[[event_col]]
      if (length(event_levels) < 3) {
        stop("Event status must have at least 3 levels (censored, main event, competing risk)")
      }

      # Resolve the target event to a single level name.
      target_event <- pv$target_event
      if (is.null(target_event)) {
        target_event <- event_levels[2]
      } else if (is.numeric(target_event)) {
        target_event <- event_levels[target_event]
      }
      if (length(target_event) != 1L || is.na(target_event) ||
        !target_event %in% event_levels) {
        stop("target_event not in event levels")
      }
      # The first level of a multi-state event factor denotes censoring.
      if (target_event == event_levels[1]) {
        stop("target_event must not be the censoring level")
      }

      row_ids <- task$row_ids
      full_data <- as.data.frame(task$backend$data(rows = row_ids,
        cols = c(task$target_names, features)))

      # Case weights aligned with the rows of `full_data` (NULL if the task has none).
      fg_case_weights <- NULL
      if ("weights" %in% task$properties) {
        weights_data <- task$weights
        if (is.null(weights_data) || !"weight" %in% names(weights_data)) {
          stop("No weights defined in task")
        }
        fg_case_weights <- weights_data$weight[match(row_ids, weights_data$row_id)]
        if (anyNA(fg_case_weights)) stop("Missing weights for some row_ids")
      }

      # `survival::Surv` is namespace-qualified so the formula also evaluates
      # when the survival package is loaded but not attached.
      form <- as.formula(paste("survival::Surv(", time_col, ",", event_col, ") ~",
        paste(features, collapse = " + ")))
      # `weights` is evaluated by finegray() via model.frame(), which looks the
      # symbol up in `data` and then in the formula's environment (this frame).
      fg_data <- survival::finegray(form, data = full_data,
        weights = fg_case_weights, etype = target_event)

      # NOTE: `fgwt` as returned by finegray() must be used unchanged. It holds
      # the time-dependent censoring weights that define the Fine-Gray model
      # (combined with the case weights supplied above); replacing it with the
      # raw case weights would turn the fit into an ordinary weighted Cox model.
      cox_formula <- as.formula(paste("survival::Surv(fgstart, fgstop, fgstatus) ~",
        paste(features, collapse = " + ")))
      survival::coxph(
        cox_formula, data = fg_data, weights = fg_data$fgwt,
        control = survival::coxph.control(
          eps = or_default(pv$eps, 1e-9),
          iter.max = or_default(pv$iter.max, 20L)
        ),
        robust = or_default(pv$robust, FALSE),
        singular.ok = or_default(pv$singular.ok, TRUE),
        ties = or_default(pv$ties, "efron")
      )
    },

    # Predict linear predictors and survival-type probabilities.
    #
    # @param task Task providing the feature columns to predict on.
    # @return A list with `crank`, `lp` and `distr` (matrix, times as colnames).
    .predict = function(task) {
      newdata <- as.data.frame(task$data(rows = task$row_ids, cols = task$feature_names))
      lp <- predict(self$model, newdata = newdata, type = "lp")

      # The baseline hazard is derived from the stored model at predict time, so
      # it is available whenever `self$model` is (e.g. after the learner has
      # been rebuilt from its state). `predict(type = "lp")` returns a
      # linear predictor centred at the covariate means, hence the matching
      # centred baseline (`centered = TRUE`) is required for
      # H(t | x) = H_centred(t) * exp(lp) to hold.
      baseline <- survival::basehaz(self$model, centered = TRUE)
      time_order <- order(baseline$time)
      surv <- exp(-outer(exp(lp), baseline$hazard[time_order], "*"))
      colnames(surv) <- baseline$time[time_order]

      list(crank = lp, lp = lp, distr = surv)
    }
  )
)
74 changes: 74 additions & 0 deletions tests/testthat/test-learner_survival_surv_finegray_coxph.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
library(testthat)
library(mlr3)
library(mlr3proba)

test_that("surv.finegray_coxph trains and predicts", {
task = TaskSurv$new(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the new mlr3proba update, you won't be able to create a TaskSurv like this (with 3 event types). TaskCompRisks will be responsible for that.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you @bblodfon for your responses.

I noticed that in your version of mlr3proba the TaskSurv does not support type="counting". Is there another task generator that takes this responsibility?
I am also wondering whether a name similar to TaskMultStates for this type of task generator would be more "inclusive" and consistent with a convention used in Surv definition instead of TaskCompRisks.
Please consider adding an argument cmp_event_labels that allows the user to pass this information.
Thank you again

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi,

I noticed that in your version of mlr3proba the TaskSurv does not support type="counting". Is there another task generator that takes this responsibility?

It's a new Task Class => TaskCompRisks, please see PR here, and I guess more helpful would be to see the tests.

I am also wondering whether a name similar to TaskMultStates for this type of task generator would be more "inclusive" and consistent with a convention used in Surv definition instead of TaskCompRisks

We noticed that as well, but there is a lot of stuff that is different (though conceptually competing risks is a subcase of multi state modeling), especially the prediction types (state probabilities can go up and down across time, CIF is always increasing), so we decided to split for now.

Please consider adding an argument cmp_event_labels that allows the user to pass this information.

I have, it just was so much easier to have the event types as integers 0,1,2,3,... due to other subsetting/filtering issues with survival::Surv() that is used inside the competing risks task. I could store the event column character mapping somewhere I guess, but it doesn't change anything in the modeling stuff we do...

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changing the learner here to a lrn("cmprsk.finegray") that outputs a CIF as mentioned in #416 would be for the best! I included the Aalen-Johansen estimator in the mlr3proba PR so that would be a good template to follow.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

RE: "I noticed that in your version of mlr3proba the TaskSurv does not support type="counting". Is there another task generator that takes this responsibility?'

This issue pertains to discussion. More specifically the code

library(mlr3)
library(mlr3proba)
mytask = as_task_surv(survival::heart, time = "start", time2 = "stop", 
         event = "event", type="counting")

worked on Oct 2nd, 2024, but now the new version of mlr3proba does not support type="counting" option and throws an error:

Error in .__TaskSurv__initialize(self = self, private = private, super = super, :
Assertion on 'type' failed: Must be element of set {'right','left','interval'}, but is 'counting'.

Would you consider reinstating type="counting" option?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, sorry for not answering directly - simply put, removing the type = "counting" is a design choice.

We are trying to see how we can address different censoring types and data formats (time-dependent covariates, start/stop data, etc.) and task types (multi-state vs single-event for example). The type argument above was about defining the type of censoring, but "counting" refers to the general [start, stop) data format (terminology was from Surv(...) which kinda mixes all these things). i.e. usually you also need an id column as one observation can have multiple rows in such a long format dataset. Not easy to see how we can reconcile this but definitely not impossible, we just need to make some design choices in mlr3/mlr3proba at some point.

id = "test",
backend = data.frame(
time = c(1, 2, 3, 4, 5),
event = factor(c(0, 1, 2, 0, 1), levels = c(0, 1, 2), labels = c("censored", "event1", "event2")),
x1 = c(1, 2, 3, 4, 5)
),
time = "time",
event = "event",
type = "mstate"
)

learner = lrn("surv.finegray_coxph")
expect_learner(learner)

# Train
learner$train(task)
expect_true(!is.null(learner$model))

# Predict
p = learner$predict(task)
expect_s3_class(p, "PredictionSurv")
expect_numeric(p$crank, len = 5)
expect_numeric(p$lp, len = 5)
expect_s3_class(p$distr, "Distribution") # Check for Distribution, not matrix
expect_true(all(c("crank", "lp", "distr") %in% names(p))) # Check names are present
})

test_that("surv.finegray_coxph handles weights", {
  # Toy competing-risks data; the extra `weight` column holds observation weights.
  event_labels <- c("censored", "event1", "event2")
  backend <- data.frame(
    time = c(1, 2, 3, 4, 5),
    event = factor(c(0, 1, 2, 0, 1), levels = c(0, 1, 2), labels = event_labels),
    x1 = c(1, 2, 3, 4, 5),
    weight = c(1, 2, 1, 2, 1)
  )
  task <- TaskSurv$new(
    id = "test_weights",
    backend = backend,
    time = "time",
    event = "event",
    type = "mstate"
  )
  # Move the `weight` column from the features to the weight role.
  task$set_col_roles("weight", roles = "weight")

  learner <- lrn("surv.finegray_coxph")
  learner$train(task)
  prediction <- learner$predict(task)

  # The prediction object exposes all three predict types.
  expect_s3_class(prediction, "PredictionSurv")
  expect_numeric(prediction$crank, len = 5)
  expect_numeric(prediction$lp, len = 5)
  expect_s3_class(prediction$distr, "Distribution")
  expect_true(all(c("crank", "lp", "distr") %in% names(prediction)))
})

test_that("surv.finegray_coxph with mstate task", {
  # Event column is called `status` here and the task id is passed positionally.
  status_labels <- c("censored", "event1", "event2")
  data <- data.frame(
    time = c(1, 2, 3, 4, 5),
    status = factor(c(0, 1, 2, 0, 1), levels = c(0, 1, 2), labels = status_labels),
    x = c(1, 2, 3, 4, 5)
  )
  task <- TaskSurv$new(
    "mstate_test",
    backend = data,
    time = "time",
    event = "status",
    type = "mstate"
  )

  learner <- lrn("surv.finegray_coxph")
  learner$train(task)
  prediction <- learner$predict(task)

  # The prediction object exposes all three predict types.
  expect_s3_class(prediction, "PredictionSurv")
  expect_numeric(prediction$crank, len = 5)
  expect_numeric(prediction$lp, len = 5)
  expect_s3_class(prediction$distr, "Distribution")
  expect_true(all(c("crank", "lp", "distr") %in% names(prediction)))
})