PEM Pipeline #417
base: main
@@ -0,0 +1,95 @@
#' @title PipeOpPredRegrSurvPEM
#' @name mlr_pipeops_trafopred_regrsurv_PEM
#'
#' @description
#' Transform [PredictionRegr] to [PredictionSurv].
#'
#' @section Dictionary:
#' This [PipeOp][mlr3pipelines::PipeOp] can be instantiated via the
#' [dictionary][mlr3misc::Dictionary] [mlr3pipelines::mlr_pipeops]
#' or with the associated sugar function [mlr3pipelines::po()]:
#' ```
#' PipeOpPredRegrSurvPEM$new()
#' mlr_pipeops$get("trafopred_regrsurv_PEM")
#' po("trafopred_regrsurv_PEM")
#' ```
#'
#' @section Input and Output Channels:
#' The input is a [PredictionRegr] and a [data.table][data.table::data.table]
#' with the transformed data, both generated by [PipeOpTaskSurvRegrPEM].
#' The output is the input [PredictionRegr] transformed to a [PredictionSurv].
#' This PipeOp only works during the prediction phase.
#'
#' @family PipeOps
#' @family Transformation PipeOps
#' @export
PipeOpPredRegrSurvPEM = R6Class(
  "PipeOpPredRegrSurvPEM",
  inherit = mlr3pipelines::PipeOp,

  public = list(
    #' @description
    #' Creates a new instance of this [R6][R6::R6Class] class.
    #' @param id (character(1))\cr
    #'   Identifier of the resulting object.
    initialize = function(id = "trafopred_regrsurv_PEM") {
      super$initialize(
        id = id,
        input = data.table(
          name = c("input", "transformed_data"),
          train = c("NULL", "data.table"),
          predict = c("PredictionRegr", "data.table")
        ),
        output = data.table(
          name = "output",
          train = "NULL",
          predict = "PredictionSurv"
        )
      )
    }
  ),

  private = list(
    .predict = function(input) {
      pred = input[[1]]
      data = input[[2]]
      assert_true(!is.null(pred$response))
      # probability of having the event (1) in each respective interval
      # is the discrete-time hazard
      data = cbind(data, dt_hazard = pred$response)

      # convert piecewise-constant hazards to survival probabilities:
      # S(t) = 1 - cumsum(hazard * interval width), where exp(offset)
      # recovers the interval width used in the PEM data transformation
      rows_per_id = nrow(data) / length(unique(data$id))
      surv = t(vapply(unique(data$id), function(unique_id) {
        id_data = data[data$id == unique_id, ]
        1 - cumsum(id_data[["dt_hazard"]] * exp(id_data[["offset"]]))
      }, numeric(rows_per_id)))

      unique_end_times = sort(unique(data$tend))
      # coerce to distribution and crank
      pred_list = .surv_return(times = unique_end_times, surv = surv)

      # select the real tend values by only selecting the last row of each id
      # basically a slightly more complex unique()
      real_tend = data$obs_times[seq_len(nrow(data)) %% rows_per_id == 0]

      ids = unique(data$id)
      # select the last row for every id => observed status at the observed time
      id = disc_status = NULL # to silence the data.table NSE note in R CMD check
      data = data[, .SD[.N, list(disc_status)], by = id]

      # create the survival prediction object
      p = PredictionSurv$new(
        row_ids = ids,
        crank = pred_list$crank, distr = pred_list$distr,
        truth = Surv(real_tend, as.integer(as.character(data$disc_status))))

      list(p)
    },

    .train = function(input) {
      self$state = list()
      list(input)
    }
  )
)

register_pipeop("trafopred_regrsurv_PEM", PipeOpPredRegrSurvPEM)
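
A minimal standalone sketch of the hazard-to-survival step in .predict() above, assuming (as in the PEM data transformation) that offset holds the log of each interval width; it contrasts the linear form used above with the standard piecewise-exponential form S(t) = exp(-cumulative hazard):

library(data.table)

# toy prediction for a single observation: three intervals with
# piecewise-constant hazard rates
toy = data.table(
  id        = 1L,
  tend      = c(1, 2, 3),
  dt_hazard = c(0.10, 0.20, 0.30),  # predicted hazard rate per interval
  offset    = log(c(1, 1, 1))       # log interval widths (here all of width 1)
)

# cumulative hazard: sum of rate * interval width over the intervals so far
cumhaz = cumsum(toy$dt_hazard * exp(toy$offset))

surv_linear = 1 - cumhaz    # linear form, as used in .predict() above
surv_exp    = exp(-cumhaz)  # piecewise-exponential form

rbind(linear = surv_linear, exponential = surv_exp)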

@@ -659,6 +659,81 @@ pipeline_survtoclassif_disctime = function(learner, cut = NULL, max_time = NULL,
  gr
}

#' @name mlr_graphs_survtoregr_PEM
#' @title Survival to Poisson Regression Reduction Pipeline
#' @description Wrapper around multiple [PipeOp][mlr3pipelines::PipeOp]s to help in the creation
#' of complex survival reduction methods.
#'
#' @param learner [LearnerRegr][mlr3::LearnerRegr]\cr
#' Regression learner to fit the transformed [TaskRegr][mlr3::TaskRegr].
#' `learner` must be able to handle an `offset`.
#' @param cut `numeric()`\cr
#' Split points, used to partition the data into intervals.
#' If unspecified, all unique event times will be used.
#' If `cut` is a single integer, it will be interpreted as the number of equidistant
#' intervals from 0 until the maximum event time.
#' @param max_time `numeric(1)`\cr
#' If `cut` is unspecified, this will be the last possible event time.
#' All event times after `max_time` will be administratively censored at `max_time`.
#' @param graph_learner `logical(1)`\cr
#' If `TRUE`, wraps the [Graph][mlr3pipelines::Graph] in a
#' [GraphLearner][mlr3pipelines::GraphLearner], otherwise (default) returns the `Graph`.
#'
#' @details
#' The pipeline consists of the following steps:
#' \enumerate{
#' \item [PipeOpTaskSurvRegrPEM] converts [TaskSurv] to a [TaskRegr][mlr3::TaskRegr].
#' \item A [LearnerRegr] is fit on the new `TaskRegr` and used for prediction.
#' \item [PipeOpPredRegrSurvPEM] transforms the resulting [PredictionRegr][mlr3::PredictionRegr]
#' to a [PredictionSurv].
#' }
#'
#' @return [mlr3pipelines::Graph] or [mlr3pipelines::GraphLearner]
#' @family pipelines
#'
#' @examples
#' \dontrun{
#' if (requireNamespace("mlr3pipelines", quietly = TRUE) &&
#'   requireNamespace("mlr3learners", quietly = TRUE)) {
#'
#' library(mlr3)
#' library(mlr3learners)
#' library(mlr3pipelines)
#'
#' task = tsk("lung")
#' part = partition(task)
#'
#' grlrn = ppl(
#'   "survtoregr_PEM",
#'   learner = lrn("regr.xgboost"),

Review comment: better example: show encoding of a factor, maybe some model.matrix trafo inside the learner? (consult Andreas)

#'   graph_learner = TRUE
#' )
#' grlrn$train(task, row_ids = part$train)
#' grlrn$predict(task, row_ids = part$test)
#' }
#' }
#' @export
pipeline_survtoregr_PEM = function(learner, cut = NULL, max_time = NULL,
                                   rhs = NULL, graph_learner = FALSE) {
  # TODO: add assertions

  gr = mlr3pipelines::Graph$new()

Review comment: you don't need the

  gr$add_pipeop(mlr3pipelines::po("trafotask_survregr_PEM", cut = cut, max_time = max_time))
  gr$add_pipeop(mlr3pipelines::po("learner", learner))
  gr$add_pipeop(mlr3pipelines::po("nop"))
  gr$add_pipeop(mlr3pipelines::po("trafopred_regrsurv_PEM"))

  gr$add_edge(src_id = "trafotask_survregr_PEM", dst_id = learner$id, src_channel = "output", dst_channel = "input")
  gr$add_edge(src_id = "trafotask_survregr_PEM", dst_id = "nop", src_channel = "transformed_data", dst_channel = "input")
  gr$add_edge(src_id = learner$id, dst_id = "trafopred_regrsurv_PEM", src_channel = "output", dst_channel = "input")
  gr$add_edge(src_id = "nop", dst_id = "trafopred_regrsurv_PEM", src_channel = "output", dst_channel = "transformed_data")

  if (graph_learner) {
    gr = mlr3pipelines::GraphLearner$new(gr)
  }

  gr
}

register_graph("survaverager", pipeline_survaverager)
register_graph("survbagging", pipeline_survbagging)
register_graph("crankcompositor", pipeline_crankcompositor)

@@ -667,3 +742,4 @@ register_graph("responsecompositor", pipeline_responsecompositor)
register_graph("probregr", pipeline_probregr)
register_graph("survtoregr", pipeline_survtoregr)
register_graph("survtoclassif_disctime", pipeline_survtoclassif_disctime)
register_graph("survregr_PEM", pipeline_survtoregr_PEM) |

Review comment: Task: I think this is the part that sometimes results in survival probabilities that are not decreasing, right?
Example:
Can we please identify why that is happening? Is it some sort of arithmetic instability, or are some of the calculations above with the offset wrong?
Review comment: You always need to specify the family/objective (depending on the learner) as "poisson" to establish the exponential link between the hazard and the learned model. So this works as intended, as long as the learner is correctly specified; however, the pipeline does not check whether that has been done. I suppose one could check for this, but I am not sure whether some learners use arguments other than features and family to specify the distributional assumption.
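
To make that concrete, here is a minimal sketch of a correctly specified learner in this pipeline, assuming regr.xgboost from mlr3learners (whose objective parameter accepts xgboost's "count:poisson") plus an encoding step for factor features, as suggested in the example comment above; exact parameter names depend on the chosen learner:

library(mlr3)
library(mlr3learners)
library(mlr3pipelines)

task = tsk("lung")
part = partition(task)

# the Poisson objective establishes the exponential link between the linear
# predictor and the piecewise-constant hazard; po("encode") one-hot encodes
# factor features for xgboost
learner = as_learner(
  po("encode") %>>%
    lrn("regr.xgboost", objective = "count:poisson", nrounds = 100)
)

grlrn = ppl("survtoregr_PEM", learner = learner, graph_learner = TRUE)
grlrn$train(task, row_ids = part$train)
grlrn$predict(task, row_ids = part$test)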
Review comment: Ah nice! We should document that in the example of the pipeline (and the vignette). If you want to go the extra mile, you can create a small function that checks the learner in the pipeline for family/objective parameters and, if it doesn't find the keyword "poisson", throws a warning such as "PEM works correctly only with learners that support Poisson regression". Andreas' list of candidate learners is enough (i.e. it is not that many either way).
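
A possible sketch of that check, under the assumption that the relevant learners expose their distributional assumption via a family, objective, or distribution hyperparameter (the helper name check_poisson_learner is hypothetical):

# Hypothetical helper sketching the suggestion above: warn if the wrapped
# regression learner does not look like it is set up for Poisson regression.
check_poisson_learner = function(learner) {
  pv = learner$param_set$values
  keys = intersect(c("family", "objective", "distribution"), names(pv))
  vals = unlist(lapply(pv[keys], function(x) tolower(as.character(x))))
  if (!any(grepl("poisson", vals))) {
    warning(sprintf(
      "Learner '%s' does not appear to use a Poisson family/objective; PEM works correctly only with learners that support Poisson regression.",
      learner$id
    ))
  }
  invisible(learner)
}

# usage sketch:
# check_poisson_learner(lrn("regr.xgboost", objective = "count:poisson"))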