PEM Pipeline #417
base: main
Conversation
…PEM, fixed minor bugs in PipeOp...PEM
…risk scenarios in the future, formula is now passed via the form argument during pipeline creation
allow cloning of measures
LinkingTo:
    Rcpp
Remotes:
    xoopR/distr6,
    xoopR/param6,
    xoopR/set6,
    mlr-org/mlr3,
Remember to remove Remotes; the new `mlr3learners` version will soon be on CRAN as well (`mlr3extralearners` is not on CRAN, so it's always the latest version from GitHub).
#' \deqn{S(t | \mathbf{x}) = \exp \left( - \int_{0}^{t} \lambda(s | \mathbf{x}) \, ds \right) = \exp \left( - \sum_{j = 1}^{J} \lambda(j | \mathbf{x}) d_j \right),}
#' where \eqn{d_j} specifies the duration of interval \eqn{j},
#'
#' we compute the survival probability from the predicted hazards.
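The piecewise-constant hazard formula above can be sketched numerically in a few lines of base R (the hazard and duration values below are made up for illustration):

```r
# lambda_j: predicted hazard in interval j; d_j: interval duration
lambda = c(0.10, 0.25, 0.40)  # made-up hazards per interval
d      = c(2, 3, 5)           # made-up interval durations

# S(t_j | x) = exp(-sum_{k <= j} lambda_k * d_k) at each interval end
surv = exp(-cumsum(lambda * d))
surv  # non-increasing whenever all hazards are >= 0
```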
Excellent! I would suggest: 1) remove the time-dependency as we don't support it, i.e. `x` instead of `x(t)`; 2) describe the `g` function a bit; 3) add a reference via the bibtex file => Andreas 2018 paper (A generalized additive model approach to time-to-event analysis).
unique_end_times = sort(unique(data$tend))
# coerce to distribution and crank
pred_list = .surv_return(times = unique_end_times, surv = surv)
Task: I think this is the part that sometimes results in surv probabilities that are not decreasing, right?
Example:
task = tsk("lung")
l = po("encode") %>>% lrn("regr.xgboost") |> as_learner()
pem = ppl("survtoregr_PEM", learner = l)
pem$train(task)$predict(task)
Can we please identify why that is happening? Is it some sort of arithmetic instability, or are some of the calculations above with the offset wrong?
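A hedged sketch to help localize it (assuming `p$data$distr` holds the survival matrix, rows = observations, columns = time points, as `.surv_return()` builds it):

```r
# find observations whose survival curve increases somewhere
surv_mat = p$data$distr  # assumption: matrix of survival probabilities
bad_rows = which(apply(surv_mat, 1L, function(s) any(diff(s) > 1e-12)))
bad_rows

# since exp(-cumsum(h * d)) cannot increase when h >= 0, any violation
# suggests negative predicted hazards or a problem in the offset handling
```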
#' time into multiple time intervals for each observation. The survival data set
#' stored in [TaskSurv] is transformed into Piece-wise Exponential Data (PED) format
#' which in turn forms the backend for [TaskRegr][mlr3::TaskRegr].
#' This transformation creates a new target variable `PEM_status` that indicates
Please replace everywhere in all pipeops and pipelines: `PEM_status` => `pem_status`
#' The target column is named `"PEM_status"` and indicates whether an event occurred
#' in each time interval.
#' An additional feature named `"tend"` contains the end time point of each interval.
#' Lastly, the "output" task has an offset column `"offset"`.
More precisely: it has a column with col_role `offset`, which is the ... log of something?
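If this follows the usual PED convention (as in `pammtools`), the offset should be the log of the time spent in each interval; a minimal sketch of that assumption:

```r
# assumed construction of the offset column for PED-style data:
# log of time at risk within each interval (tstart, tend]
tstart = c(0, 2, 5)          # made-up interval starts
tend   = c(2, 5, 8)          # made-up interval ends
offset = log(tend - tstart)  # please verify against the pipeop code
```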
#' [TaskRegr][mlr3::TaskRegr].
#' The target column is named `"PEM_status"` and indicates whether an event occurred
#' in each time interval.
#' An additional feature named `"tend"` contains the end time point of each interval.
...numeric feature... (please verify) => add this also in the DiscTime pipeop
#'
#' During prediction, the "input" [TaskSurv] is transformed to the "output"
#' [TaskRegr][mlr3::TaskRegr] with `"PEM_status"` as target, while `"tend"`
#' and `"offset"` are included as features.
More accurately: offset is not a feature, i.e. it doesn't have the col_role `feature`, but the `offset` one.
#' unique(task_regr$data(cols = "tend"))[[1L]]
#'
#' # train a regression learner
#' learner = lrn("regr.gam") # won't run unless learner can accept offset column role
TODO: when I finish the `mlr3extralearners` PR, we can safely remove this comment here.
Also correct the example + make it a bit more interesting, e.g. => `l = lrn("regr.gam", formula = pem_status ~ s(age) + s(tend), family = "poisson")` => you definitely need the `family = "poisson"` argument here.
#' # the end time points of the discrete time intervals
#' unique(task_regr$data(cols = "tend"))[[1L]]
#'
#' # train a regression learner
... that supports poisson regression....
assert(max_time > data[get(event_var) == 1, min(get(time_var))],
  "max_time must be greater than the minimum event time.")
}
removing redundant empty lines in all code would be nice - some space is good, more space is unnecessary
long_data[, id := ids]

task_PEM = TaskRegr$new(paste0(task$id, "_PEM"), long_data,
  target = "PEM_status")
A bit more proper indentation style: `target` should be aligned below `new(` here; please check all code for this.
task_PEM = TaskRegr$new(paste0(task$id, "_PEM"), long_data,
  target = "PEM_status")
task_PEM$set_col_roles("id", roles = "original_ids")
task_PEM$set_col_roles('offset', roles = "offset")
Style: no `'` anywhere in the code please, use only `"`!
status = data[[event_var]]
data[[event_var]] = 1
instead of extra space: a good informative comment!
@@ -51,7 +51,8 @@ register_reflections = function() {

x$task_col_roles$surv = x$task_col_roles$regr
x$task_col_roles$dens = c("feature", "target", "label", "order", "group", "weight", "stratum")
x$task_col_roles$classif = unique(c(x$task_col_roles$classif, "original_ids")) # for discrete time
x$task_col_roles$regr = unique(c(x$task_col_roles$regr, "original_ids"))
# for pem
#' @param graph_learner `logical(1)`\cr
#' If `TRUE`, wraps the [Graph][mlr3pipelines::Graph] as a
#' [GraphLearner][mlr3pipelines::GraphLearner], otherwise (default) returns it as a `Graph`.
#' @param rhs (`character(1)`)\cr
please remove! also in disc time pipeline!
#'
#' grlrn = ppl(
#'   "survtoregr_PEM",
#'   learner = lrn("regr.xgboost")
Better example: show encoding of a factor, maybe some modelmatrix trafo inside the learner? (consult Andreas)
rhs = NULL, graph_learner = FALSE) {

assert_true("offset" %in% learner$properties)
assert_learner(learner, task_type = "regr")
Combine the two: `assert_learner()` can check for properties too!
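If memory serves, `assert_learner()` accepts a `properties` argument, so the two checks could collapse into one call (please verify against the mlr3 version in DESCRIPTION):

```r
# single assertion: regression learner that supports the offset property
assert_learner(learner, task_type = "regr", properties = "offset")
```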
maybe disc time pipeline has the same?
}
}

gr = mlr3pipelines::Graph$new()
You don't need the `mlr3pipelines::` prefix as we import the package now, see other pipelines (so remove).
gr$add_edge(src_id = "nop", dst_id = "trafopred_regrsurv_PEM", src_channel = "output", dst_channel = "transformed_data")

if (!is.null(rhs)) {
remove => also in disc time!
@@ -0,0 +1,75 @@
test_that("PipeOpTaskSurvRegrPEM", {
`PEM` => `pem`
task = tsk('rats')
# for this section, select only numeric covariates,
# as 'regr.glmnet' does not automatically handle factor type variables
task$select(c('litter', 'rx'))
or `po("encode")`!
expect_class(grlrn, "GraphLearner")
suppressWarnings(grlrn$train(task))
p = grlrn$predict(task)
expect_prediction_surv(p)
Check that `ncol(p$data$distr) == 3`? And exactly the specific cut points? (if I recall correctly those are the time points used)
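A sketch of the extra assertions (the column count and cut points here are assumptions; adjust to whatever `cut` the test actually passes):

```r
surv_mat = p$data$distr  # assumption: survival matrix, cols = time points
expect_equal(ncol(surv_mat), 3L)  # hypothetical: 3 cut points were used
# optionally also check the exact cut points:
# expect_equal(as.numeric(colnames(surv_mat)), cut_points)
```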
p = grlrn$predict(task)
expect_prediction_surv(p)

# Test with rhs
Refactor with `modelmatrix` as a pipeop (as `rhs` is removed).
Maybe better with the `gam` when the `mlr3extralearners` PR is finished...
private = list(
  .train = function(input) {
    task = input[[1L]]
If you want to experiment and implement the validation stuff for `xgboost`, here is a bit of what is happening: `task` will have a predefined validation task here, which is not transformed. What we need to do is something like:
transformed_internal_valid_task = private$.train(list(task$internal_valid_task))
task$internal_valid_task = transformed_internal_valid_task
and then go on transforming the task.