-
-
Notifications
You must be signed in to change notification settings - Fork 21
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
PEM Pipeline #417
base: main
Are you sure you want to change the base?
PEM Pipeline #417
Changes from 1 commit
6eb5a6e
4d49fa3
d786d08
b7bc0c6
5d6b61b
c19e4bb
717478d
976d9c0
35745f0
e2a6c21
5a75617
a3bdbf5
6bf7332
690b5c6
f341679
d684cd0
b14e678
f00f1f9
7a848bf
2fc11a4
c035e9b
cb52f0b
4e509e4
82b8af7
0532d71
ec0ffdd
a5402c3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,7 +4,9 @@ | |
#' | ||
#' @description | ||
#' Transform [TaskSurv] to [TaskRegr][mlr3::TaskRegr] by dividing continuous | ||
#' time into multiple time intervals for each observation. | ||
#' time into multiple time intervals for each observation. The survival data set | ||
#' stored in [TaskSurv] is transformed into Piece-wise Exponential Data (PED) format | ||
#' which in turn forms the backend for [TaskRegr][mlr3::TaskRegr]. | ||
#' This transformation creates a new target variable `PEM_status` that indicates | ||
bblodfon marked this conversation as resolved.
Show resolved
Hide resolved
|
||
#' whether an event occurred within each time interval. | ||
#' | ||
|
@@ -31,8 +33,8 @@ | |
#' The "transformed_data" is an empty [data.table][data.table::data.table]. | ||
#' | ||
#' During prediction, the "input" [TaskSurv] is transformed to the "output" | ||
#' [TaskRegr][mlr3::TaskRegr] with `"PEM_status"` as target and the `"tend"` | ||
#' as well as `"offset"` feature included. | ||
#' [TaskRegr][mlr3::TaskRegr] with `"PEM_status"` as target, while `"tend"` | ||
#' and `"offset"` are included as features. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. mroe accurately: offset in not a |
||
#' The "transformed_data" is a [data.table] with columns the `"PEM_status"` | ||
#' target of the "output" task, the `"id"` (original observation ids), | ||
#' `"obs_times"` (observed times per `"id"`) and `"tend"` (end time of each interval). | ||
|
@@ -53,12 +55,10 @@ | |
#' If `cut` is unspecified, this will be the last possible event time. | ||
#' All event times after `max_time` will be administratively censored at `max_time.` | ||
#' Needs to be greater than the minimum event time in the given task. | ||
#' | ||
#' @examples | ||
#' | ||
#' @examplesIf mlr3misc::require_namespaces(c("mlr3pipelines", "mlr3learners"), quietly = TRUE) | ||
#' * `ped_formula` | ||
#' TODO | ||
#' @examplesIf (mlr3misc::require_namespaces(c("mlr3pipelines", "mlr3extralearners"), quietly = TRUE)) | ||
#' \dontrun{ | ||
#' # Update documentation to match PEM | ||
#' library(mlr3) | ||
#' library(mlr3learners) | ||
#' library(mlr3pipelines) | ||
|
@@ -73,11 +73,10 @@ | |
#' # the end time points of the discrete time intervals | ||
#' unique(task_regr$data(cols = "tend"))[[1L]] | ||
#' | ||
#' # train a classification learner | ||
#' learner = lrn("classif.log_reg", predict_type = "prob") | ||
#' # train a regression learner | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ... that supports poisson regression.... |
||
#' learner = lrn("regr.gam") # won't run unless learner can accept offset column role | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. TODO: when I finish the Also correct the example + make it a bit more interesting, e.g. => |
||
#' learner$train(task_regr) | ||
#' } | ||
#' } | ||
#' } | ||
#' | ||
#' | ||
#' @family PipeOps | ||
|
@@ -95,7 +94,7 @@ PipeOpTaskSurvRegrPEM = R6Class("PipeOpTaskSurvRegrPEM", | |
max_time = p_dbl(0, default = NULL, special_vals = list(NULL)), | ||
censor_code = p_int(0L), | ||
min_events = p_int(1L), | ||
form = p_uty(tags = 'train') | ||
ped_formula = p_uty(tags = 'train', default = NULL) | ||
#pammtools arguments: transitions etc. | ||
) | ||
super$initialize( | ||
|
@@ -139,16 +138,16 @@ PipeOpTaskSurvRegrPEM = R6Class("PipeOpTaskSurvRegrPEM", | |
"max_time must be greater than the minimum event time.") | ||
} | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. removing redundant empty lines in all code would be nice - some space is good, more space is unnecessary |
||
# To-Do: Extend to a more general formulation for competing risks and msm | ||
# Issue: We pass form (e.g. Surv(time, status) ~ .) which currently serves to correctly transform the data into ped format | ||
# but doesn't serve any other purpose yet. For ML learners, such as xgb, the covariate structure is passed to the pipeline via rhs not form. | ||
long_data = pammtools::as_ped(data = data, formula = self$param_set$values$form, cut = cut, max_time = max_time) | ||
self$state$cut = attributes(long_data)$trafo_args$cut | ||
ped_formula = self$param_set$values$ped_formula | ||
if (is.null(ped_formula)){ | ||
ped_formula = formulate(sprintf("Surv(%s, %s)", time_var, event_var), ".") | ||
} | ||
long_data = pammtools::as_ped(data = data, formula = ped_formula, cut = cut, max_time = max_time) | ||
long_data = as.data.table(long_data) | ||
|
||
|
||
self$state$cut = attributes(long_data)$trafo_args$cut | ||
|
||
long_data = as.data.table(long_data) | ||
setnames(long_data, old = "ped_status", new = "PEM_status") #change to PEM | ||
setnames(long_data, old = "ped_status", new = "PEM_status") | ||
|
||
# remove some columns from `long_data` | ||
long_data[, c("tstart", "interval") := NULL] | ||
|
@@ -161,6 +160,7 @@ PipeOpTaskSurvRegrPEM = R6Class("PipeOpTaskSurvRegrPEM", | |
task_PEM = TaskRegr$new(paste0(task$id, "_PEM"), long_data, | ||
target = "PEM_status") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. a bit more proper indentation style => |
||
task_PEM$set_col_roles("id", roles = "original_ids") | ||
task_PEM$set_col_roles('offset', roles = "offset") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. style: no |
||
|
||
list(task_PEM, data.table()) | ||
}, | ||
|
@@ -181,9 +181,15 @@ PipeOpTaskSurvRegrPEM = R6Class("PipeOpTaskSurvRegrPEM", | |
|
||
status = data[[event_var]] | ||
data[[event_var]] = 1 | ||
|
||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. instead of extra space: a good informative comment! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Frankly, not quite sure what's the purpose of |
||
ped_formula = self$param_set$values$ped_formula | ||
if (is.null(ped_formula)){ | ||
ped_formula = formulate(sprintf("Surv(%s, %s)", time_var, event_var), ".") | ||
} | ||
long_data = pammtools::as_ped(data = data, formula = ped_formula, cut = cut, max_time = max_time) | ||
long_data = as.data.table(long_data) | ||
|
||
|
||
long_data = as.data.table(pammtools::as_ped(data, formula = self$param_set$values$form, cut = cut)) | ||
setnames(long_data, old = "ped_status", new = "PEM_status") | ||
|
||
PEM_status = id = tend = obs_times = NULL # fixing global binding notes of data.table | ||
|
@@ -204,6 +210,7 @@ PipeOpTaskSurvRegrPEM = R6Class("PipeOpTaskSurvRegrPEM", | |
task_PEM = TaskRegr$new(paste0(task$id, "_PEM"), long_data, | ||
target = "PEM_status") | ||
task_PEM$set_col_roles("id", roles = "original_ids") | ||
task_PEM$set_col_roles('offset', roles = "offset") | ||
|
||
# map observed times back | ||
reps = table(long_data$id) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Excellent! I would suggest 1) remove the time-dependency as we don't support it, ie
x
instead ofx(t)
2) describe a bitg
function? 3) add reference via the bibtex file => Andreas 2018 paper (A generalized additive model approach to time-to-event analysis)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added some additional descriptions for clarification. At the same time, I am wondering whether PipeOpPred is the appropriate place for these mathematical explanations. I feel, as a whole, they might be more appropriate in the doc of the pipeline, with exception of the backtransform portion, i.e. how we get surv probs from hazards.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see your point - it makes sense to put the related doc directly in the class that implementes what the docs says - and provide the links to that in the pipeline, but up to you if you want to add extra math doc somewhere (even a bit duplicated), always welcome!