-
-
Notifications
You must be signed in to change notification settings - Fork 53
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add surv.finegray _coxph #417
Open
agalecki
wants to merge
17
commits into
mlr-org:main
Choose a base branch
from
agalecki-forks:add-surv-finegray-coxph
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
17 commits
Select commit
Hold shift + click to select a range
d3f42ae
Create learner_survival_surv_finegray_coxph.R
agalecki 2180486
Update learner_survival_surv_finegray_coxph.R
agalecki dd7f81b
NS+Rd
agalecki fa92441
Create test-learner_survival_surv_finegray_coxph.R
agalecki 0349e63
DESCRIPTION+NEWs.md modified
agalecki b44d429
Update DESCRIPTION
agalecki af14ab9
Update DESCRIPTION
agalecki 97d3fb6
Merge branch 'add-surv-finegray-coxph' of https://github.com/agalecki…
agalecki d86a85c
Update NEWS.md
agalecki f164337
Update learner_survival_surv_finegray_coxph.R
agalecki fba3cf9
Update learner_survival_surv_finegray_coxph.R
agalecki 326728d
Update learner_survival_surv_finegray_coxph.R
agalecki f41656a
Update learner_survival_surv_finegray_coxph.R
agalecki 534b2d6
Update learner_survival_surv_finegray_coxph.R
agalecki fd011c1
Update learner_survival_surv_finegray_coxph.R
agalecki 65b6513
Update learner_survival_surv_finegray_coxph.R
agalecki 51aae95
Delete LearnerSurvFineGrayCoxPH.Rd
agalecki File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,7 +22,8 @@ Authors@R: c( | |
comment = c(ORCID = "0000-0002-3609-8674")), | ||
person("Lukas", "Burk", , "[email protected]", role = "ctb", | ||
comment = c(ORCID = "0000-0001-7528-3795")), | ||
person("Lona", "Koers", , "[email protected]", role = "ctb") | ||
person("Lona", "Koers", , "[email protected]", role = "ctb"), | ||
person("Andrzej", "Galecki", "[email protected]", role ="ctb") | ||
) | ||
Description: Extra learners for use in mlr3. | ||
License: LGPL-3 | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
#' @title Fine-Gray Competing Risks Model with Cox Proportional Hazards | ||
#' @description | ||
#' A learner for fitting a Fine-Gray competing risks model using Cox proportional hazards. | ||
#' Estimates subdistribution hazard for a specified event type with competing events, supporting weights. | ||
#' @section Usage: | ||
#' ``` | ||
#' learner <- LearnerSurvFineGrayCoxPH$new() | ||
#' ``` | ||
#' @section Parameters: | ||
#' - `ties`: Character, method for handling ties ("efron" (default), "breslow", "exact"). | ||
#' - `iter.max`: Integer, max iterations for Cox fit (default: 20, range: 1-1000). | ||
#' - `eps`: Numeric, convergence threshold (default: 1e-9, range: 1e-12 to 1e-4). | ||
#' - `robust`: Logical, compute robust variance (default: FALSE). | ||
#' - `target_event`: Event type to model (default: second level if NULL). | ||
#' - `singular.ok`: Logical, allow singular predictors (default: TRUE). | ||
#' @section Predict Types: | ||
#' - `crank`: Continuous ranking (linear predictor). | ||
#' - `lp`: Linear predictor. | ||
#' - `distr`: Survival probabilities (matrix with times as colnames). | ||
#' @section Properties: | ||
#' - Supports observation weights via the task's weights role (set via `weights` argument or `$set_col_roles()`). | ||
#' @return For `predict()`, a list with `crank` (numeric), `lp` (numeric), and `distr` | ||
#' (matrix of survival probabilities with times as colnames). | ||
#' @section Methods: | ||
#' - `new()`: Alias for `initialize()`, creates a new instance of the learner with default parameters. | ||
#' - `train(task)`: Train the model on a survival task. | ||
#' - `predict(task)`: Predict on new data from a trained model. | ||
#' @importFrom R6 R6Class | ||
#' @importFrom mlr3proba LearnerSurv TaskSurv | ||
#' @importFrom paradox ps p_fct p_int p_dbl p_lgl p_uty | ||
#' @importFrom survival finegray coxph basehaz coxph.control | ||
#' @export | ||
LearnerSurvFineGrayCoxPH <- R6::R6Class("LearnerSurvFineGrayCoxPH", | ||
inherit = mlr3proba::LearnerSurv, | ||
public = list( | ||
#' @method initialize LearnerSurvFineGrayCoxPH | ||
#' @description | ||
#' Initialize a new Fine-Gray Cox PH learner with its default hyperparameter settings. | ||
#' No arguments are required; hyperparameters are defined and set via the `paradox` parameter set. | ||
initialize = function() { | ||
ps <- paradox::ps( | ||
ties = p_fct(default = "efron", levels = c("efron", "breslow", "exact"), tags = "train"), | ||
iter.max = p_int(default = 20L, lower = 1L, upper = 1000L, tags = "train"), | ||
eps = p_dbl(default = 1e-9, lower = 1e-12, upper = 1e-4, tags = "train"), | ||
robust = p_lgl(default = FALSE, tags = "train"), | ||
target_event = p_uty(default = NULL, tags = "train", | ||
custom_check = function(x) { | ||
if (is.null(x)) return(TRUE) | ||
if (is.character(x) || is.numeric(x)) return(TRUE) | ||
"target_event must be NULL, a character, or a numeric index" | ||
}), | ||
singular.ok = p_lgl(default = TRUE, tags = "train") | ||
) | ||
ps$values <- list( | ||
ties = "efron", | ||
iter.max = 20L, | ||
eps = 1e-9, | ||
robust = FALSE, | ||
target_event = NULL, | ||
singular.ok = TRUE | ||
) | ||
super$initialize( | ||
id = "surv.finegray_coxph", | ||
param_set = ps, | ||
feature_types = c("logical", "integer", "numeric", "factor"), | ||
predict_types = c("crank", "lp", "distr"), | ||
properties = "weights", | ||
packages = "survival", | ||
label = "Fine-Gray Competing Risks Model with CoxPH", | ||
man = "mlr3SurvUtils::LearnerSurvFineGrayCoxPH" | ||
) | ||
} | ||
), | ||
private = list( | ||
basehaz = NULL, | ||
.train = function(task) { | ||
#print("ParamSet in self:") | ||
#print(self$param_set) # Check if param_set exists | ||
pv <- self$param_set$get_values(tags = "train") | ||
#print("pv contents:") | ||
#print(pv) # Debug: Check what pv contains | ||
if (is.null(pv$iter.max) || !is.integer(pv$iter.max) || pv$iter.max < 1) { | ||
stop("Invalid iter.max: ", pv$iter.max) | ||
} | ||
row_ids <- task$row_ids | ||
full_data <- as.data.frame(task$backend$data(rows = row_ids, | ||
cols = c(task$target_names, task$feature_names))) | ||
features <- task$feature_names | ||
if (length(features) == 0) stop("No features provided!") | ||
time_col <- task$target_names[1] | ||
event_col <- task$target_names[2] | ||
event_levels <- task$levels()[[event_col]] | ||
full_data$id <- seq_len(nrow(full_data)) | ||
weights <- NULL | ||
if ("weights" %in% task$properties) { | ||
weights_data <- task$weights | ||
if (is.null(weights_data) || !"weight" %in% names(weights_data)) { | ||
stop("No weights defined in task") | ||
} | ||
all_weights <- weights_data$weight | ||
all_row_ids <- weights_data$row_id | ||
weight_map <- setNames(all_weights, all_row_ids) | ||
weights <- weight_map[as.character(row_ids)] | ||
if (any(is.na(weights))) stop("Missing weights for some row_ids") | ||
} | ||
form <- as.formula(paste("Surv(", time_col, ",", event_col, ") ~", | ||
paste(c(features, "id"), collapse = " + "))) | ||
if (length(event_levels) < 3) { | ||
stop("Event status must have at least 3 levels (censored, main event, competing risk)") | ||
} | ||
target_event <- if (is.null(pv$target_event)) { | ||
event_levels[2] | ||
} else { | ||
if (is.numeric(pv$target_event)) { | ||
event_levels[pv$target_event] | ||
} else { | ||
pv$target_event | ||
} | ||
} | ||
if (!target_event %in% event_levels) stop("target_event not in event levels") | ||
fg_data <- survival::finegray(form, data = full_data, etype = target_event, weights = weights) | ||
if ("weights" %in% task$properties) { | ||
matched_indices <- match(fg_data$id, as.integer(names(weight_map))) | ||
fg_data$fgwt <- weights[matched_indices] | ||
} | ||
cox_formula <- as.formula(paste("Surv(fgstart, fgstop, fgstatus) ~", | ||
paste(features, collapse = " + "))) | ||
model <- survival::coxph( | ||
cox_formula, data = fg_data, weights = fg_data$fgwt, | ||
control = coxph.control(eps = pv$eps, iter.max = pv$iter.max), | ||
robust = pv$robust, singular.ok = pv$singular.ok, ties = pv$ties | ||
) | ||
basehaz <- survival::basehaz(model, centered = FALSE) | ||
private$basehaz <- list(time = basehaz$time, cumhaz = basehaz$hazard) | ||
model | ||
}, | ||
.predict = function(task) { | ||
newdata <- as.data.frame(task$data(rows = task$row_ids, cols = task$feature_names)) | ||
lp <- predict(self$model, newdata = newdata, type = "lp") | ||
cumhaz <- private$basehaz$cumhaz | ||
time_order <- order(private$basehaz$time) | ||
surv <- exp(-outer(exp(lp), cumhaz, "*")) | ||
surv <- surv[, time_order, drop = FALSE] | ||
colnames(surv) <- private$basehaz$time[time_order] | ||
list(crank = lp, lp = lp, distr = surv) | ||
} | ||
) | ||
) |
74 changes: 74 additions & 0 deletions
74
tests/testthat/test-learner_survival_surv_finegray_coxph.R
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
library(testthat) | ||
library(mlr3) | ||
library(mlr3proba) | ||
|
||
test_that("surv.finegray_coxph trains and predicts", { | ||
task = TaskSurv$new( | ||
id = "test", | ||
backend = data.frame( | ||
time = c(1, 2, 3, 4, 5), | ||
event = factor(c(0, 1, 2, 0, 1), levels = c(0, 1, 2), labels = c("censored", "event1", "event2")), | ||
x1 = c(1, 2, 3, 4, 5) | ||
), | ||
time = "time", | ||
event = "event", | ||
type = "mstate" | ||
) | ||
|
||
learner = lrn("surv.finegray_coxph") | ||
expect_learner(learner) | ||
|
||
# Train | ||
learner$train(task) | ||
expect_true(!is.null(learner$model)) | ||
|
||
# Predict | ||
p = learner$predict(task) | ||
expect_s3_class(p, "PredictionSurv") | ||
expect_numeric(p$crank, len = 5) | ||
expect_numeric(p$lp, len = 5) | ||
expect_s3_class(p$distr, "Distribution") # Check for Distribution, not matrix | ||
expect_true(all(c("crank", "lp", "distr") %in% names(p))) # Check names are present | ||
}) | ||
|
||
test_that("surv.finegray_coxph handles weights", { | ||
task = TaskSurv$new( | ||
id = "test_weights", | ||
backend = data.frame( | ||
time = c(1, 2, 3, 4, 5), | ||
event = factor(c(0, 1, 2, 0, 1), levels = c(0, 1, 2), labels = c("censored", "event1", "event2")), | ||
x1 = c(1, 2, 3, 4, 5), | ||
weight = c(1, 2, 1, 2, 1) # Add weights to backend | ||
), | ||
time = "time", | ||
event = "event", | ||
type = "mstate" | ||
) | ||
task$set_col_roles("weight", roles = "weight") # Set weight role | ||
|
||
learner = lrn("surv.finegray_coxph") | ||
learner$train(task) | ||
p = learner$predict(task) | ||
expect_s3_class(p, "PredictionSurv") | ||
expect_numeric(p$crank, len = 5) | ||
expect_numeric(p$lp, len = 5) | ||
expect_s3_class(p$distr, "Distribution") | ||
expect_true(all(c("crank", "lp", "distr") %in% names(p))) | ||
}) | ||
|
||
test_that("surv.finegray_coxph with mstate task", { | ||
data = data.frame( | ||
time = c(1, 2, 3, 4, 5), | ||
status = factor(c(0, 1, 2, 0, 1), levels = c(0, 1, 2), labels = c("censored", "event1", "event2")), | ||
x = c(1, 2, 3, 4, 5) | ||
) | ||
task = TaskSurv$new("mstate_test", backend = data, time = "time", event = "status", type = "mstate") | ||
learner = lrn("surv.finegray_coxph") | ||
learner$train(task) | ||
p = learner$predict(task) | ||
expect_s3_class(p, "PredictionSurv") | ||
expect_numeric(p$crank, len = 5) | ||
expect_numeric(p$lp, len = 5) | ||
expect_s3_class(p$distr, "Distribution") | ||
expect_true(all(c("crank", "lp", "distr") %in% names(p))) | ||
}) |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In the new
mlr3proba
update, you won't be ableto create aTaskSurv
like this (with 3 event types).TaskCompRisks
will be responsible for that.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you @bblodfon for your responses.
I noticed that in your version of
mlr3proba
theTaskSurv
does not supporttype="counting"
. Is there another task generator that takes this responsibility?I am also wondering whether a name similar to
TaskMultStates
for this type of task generator would be more "inclusive" and consistent with a convention used inSurv
definition instead ofTaskCompRisks
.Please consider adding an argument
cmp_event_labels
that allows the user to pass this information.Thank you again
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi,
It's a new Task Class =>
TaskCompRisks
, please see PR here, and I guess more helpful would be to see the tests.We noticed that as well, but there is a lot of stuff that is different (though conceptually competing risks is a subcase of multi state modeling), especially the prediction types (state probabilities can go up and down across time, CIF is always increasing), so we decided to split for now.
I have, it just was so much easier to have the event types with intergers 0,1,2,3,... due to other subsetting/filtering issues with
survival::Surv()
that is used inside the competing risks task. I could store the event column character mapping stored somewhere I guess, but it doesn't change anything in the modeling stuff we do...There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changing the learner here to a
lrn("cmprsk.finegray")
that outputs a CIF as mentioned in #416 would be for the best! I included the Aalen-Johansen estimator in the mlr3proba PR so that would be a good template to follow.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
RE: "I noticed that in your version of mlr3proba the TaskSurv does not support
type="counting"
. Is there another task generator that takes this responsibility?'This issue pertains to discussion. More specifically the code
worked on Oct 2nd, 2024, but now the new version of
mlr3proba
does not supporttype="counting"
option and throws an error:Error in .__TaskSurv__initialize(self = self, private = private, super = super, :
Assertion on 'type' failed: Must be element of set {'right','left','interval'}, but is 'counting'.
Would you consider reinstating
type="counting"
option?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, sorry for now answering directly - simply put, removing the
type = "counting"
is a design choice.We are trying to see how we can address different censoring types and data formats (time-dependent covariates, start/stop data, etc.) and task types (multi-state vs single-event for example). The
type
argument above was about defining the type of censoring, but"counting"
refers to the general[start, stop)
data format (terminology was fromSurv(...)
which kinda mixes all these things). i.e. usually you also need anid
column as one observation can have multiple rows in such a long format dataset. Not easy to see how we can reconcile this but definitely not impossible, we just need to make some design choices inmlr3
/mlr3proba
at some point.