PEM Pipeline #417

Open
wants to merge 27 commits into
base: main
6eb5a6e
draft task conversion pipeop
studener Sep 24, 2024
4d49fa3
draft pipeop + pipeline
studener Sep 26, 2024
d786d08
update pred conversion pipeop
studener Sep 26, 2024
b7bc0c6
added modelmatrix pipeop to PEM pipeline, changed variable naming to …
markusgoeswein Oct 29, 2024
5d6b61b
added col_role original_ids to regression tasks
markusgoeswein Nov 14, 2024
c19e4bb
changed id column role to original_ids
markusgoeswein Nov 14, 2024
717478d
added additional arguments to TaskSurvRegrPEM to enable more complex …
markusgoeswein Nov 21, 2024
976d9c0
form is now to be passed without quotation marks
markusgoeswein Dec 6, 2024
35745f0
resolve merge conflict with main, before merging
markusgoeswein Jan 31, 2025
e2a6c21
resolve merge conflict in R\piplines.R
markusgoeswein Jan 31, 2025
5a75617
setting up unit tests for the PEM pipeline
markusgoeswein Feb 21, 2025
a3bdbf5
update function doc
markusgoeswein Feb 25, 2025
6bf7332
update remotes for offset support
markusgoeswein Feb 25, 2025
690b5c6
remove ped_formula argument in favour of automatically parsing it
markusgoeswein Mar 3, 2025
f341679
add assert and set use_pred_offset to FALSE if not done so
markusgoeswein Mar 3, 2025
d684cd0
Regenerate Rd files using devtools::document()
markusgoeswein Mar 3, 2025
b14e678
adjust assertions in pipeline
markusgoeswein Mar 4, 2025
f00f1f9
Merge pull request #436 from mlr-org/main
markusgoeswein Mar 4, 2025
7a848bf
setting up test_PEM.R and adjustments to tests in pipelines
markusgoeswein Mar 4, 2025
2fc11a4
change lrn() from regr.xgboost to regr.glmnet
markusgoeswein Mar 13, 2025
c035e9b
update DESCRIPTION
markusgoeswein Mar 13, 2025
cb52f0b
add glmnet to suggests for PEM pipeline tests
markusgoeswein Mar 13, 2025
4e509e4
included require_namespace('glmnet, ...) for PEM pipeline tests
markusgoeswein Mar 13, 2025
82b8af7
minor fix
markusgoeswein Mar 13, 2025
0532d71
set PEM to lowercase in PipeOpTask and PipeOpPred, update DESCRIPTION…
markusgoeswein Mar 25, 2025
ec0ffdd
temporary name change of PEM pipeops
markusgoeswein Mar 25, 2025
a5402c3
man files are renamed with lowercase pem
markusgoeswein Mar 25, 2025
draft pipeop + pipeline
studener committed Sep 26, 2024
commit 4d49fa33ad1bf258934015992ea7b400fabbe36f
95 changes: 95 additions & 0 deletions R/PipeOpPredRegrSurvPEM.R
@@ -0,0 +1,95 @@
#' @title PipeOpPredRegrSurvPEM
#' @name mlr_pipeops_trafopred_regrsurv_PEM
#'
#' @description
#' Transform [PredictionRegr] to [PredictionSurv].
#'
#' @section Dictionary:
#' This [PipeOp][mlr3pipelines::PipeOp] can be instantiated via the
#' [dictionary][mlr3misc::Dictionary] [mlr3pipelines::mlr_pipeops]
#' or with the associated sugar function [mlr3pipelines::po()]:
#' ```
#' PipeOpPredRegrSurvPEM$new()
#' mlr_pipeops$get("trafopred_regrsurv_PEM")
#' po("trafopred_regrsurv_PEM")
#' ```
#'
#' @section Input and Output Channels:
#' The input is a [PredictionRegr] and a [data.table][data.table::data.table]
#' with the transformed data both generated by [PipeOpTaskSurvRegrPEM].
#' The output is the input [PredictionRegr] transformed to a [PredictionSurv].
#' Only works during prediction phase.
#'
#' @family PipeOps
#' @family Transformation PipeOps
#' @export
PipeOpPredRegrSurvPEM = R6Class(
"PipeOpPredRegrSurvPEM",
inherit = mlr3pipelines::PipeOp,

public = list(
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
#' @param id (character(1))\cr
#' Identifier of the resulting object.
initialize = function(id = "trafopred_regrsurv_PEM") {
super$initialize(
id = id,
input = data.table(
name = c("input", "transformed_data"),
train = c("NULL", "data.table"),
predict = c("PredictionRegr", "data.table")
),
output = data.table(
name = "output",
train = "NULL",
predict = "PredictionSurv"
)
)
}
),

private = list(
.predict = function(input) {
pred = input[[1]]
data = input[[2]]
assert_true(!is.null(pred$response))
# probability of having the event (1) in each respective interval
# is the discrete-time hazard
data = cbind(data, dt_hazard = pred$response)

# From theory, convert hazards to surv as prod(1 - h(t))
rows_per_id = nrow(data) / length(unique(data$id))
surv = t(vapply(unique(data$id), function(unique_id) {
1 - cumsum(data[data$id == unique_id, ][["dt_hazard"]] * exp(data[data$id == unique_id, ][["offset"]]))
}, numeric(rows_per_id)))

unique_end_times = sort(unique(data$tend))
# coerce to distribution and crank
pred_list = .surv_return(times = unique_end_times, surv = surv)
Collaborator

Task: I think this is the part that sometimes results in survival probabilities that are not decreasing, right?
Example:

task = tsk("lung")
l = po("encode") %>>% lrn("regr.xgboost") |> as_learner()
pem = ppl("survtoregr_PEM", learner = l)
pem$train(task)$predict(task)

Can we please identify why that is happening? Is it some sort of arithmetic instability, or are some of the calculations above involving the offset wrong?

Collaborator

@markusgoeswein Mar 25, 2025

task = tsk("lung")
l = po("encode") %>>% lrn("regr.xgboost", objective = 'count:poisson') |> as_learner()
pem = ppl("survtoregr_PEM", learner = l)
pem$train(task)
pred = pem$predict(task)

You always need to set the family/objective (depending on the learner) to "poisson" to establish the exponential link between the hazard and the learned model. So this works as intended as long as the learner is correctly specified; however, the pipeline does not check whether that has been done. I suppose one could check for this, but I am not sure whether some learners use arguments other than family/objective to specify the distributional assumption.
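To illustrate why the link matters, here is a minimal base-R sketch. It is not the pipeline's code: the numbers are made up, `interval_len` stands in for `exp(offset)`, and the helper names are hypothetical. It assumes the non-Poisson learner returns unconstrained values (a squared-error objective can predict negative numbers), while a log link always yields positive hazards.

```r
# Made-up per-interval predictions for a single subject.
interval_len = c(1, 1, 1, 1)                     # stands in for exp(offset)
pred_gauss   = c(0.10, -0.05, 0.20, -0.02)       # unconstrained output, may be negative
pred_pois    = exp(c(-2.3, -3.0, -1.6, -3.9))    # log link: strictly positive

# Conversion as written in the diff: 1 - cumulative (hazard * interval length).
surv_from = function(hazard, len) 1 - cumsum(hazard * len)
# Piecewise-exponential survival function: exp(-cumulative hazard).
surv_pem = function(hazard, len) exp(-cumsum(hazard * len))

# TRUE if the curve never goes up between consecutive intervals.
is_non_increasing = function(s) all(diff(s) <= 0)

is_non_increasing(surv_from(pred_gauss, interval_len)) # FALSE
is_non_increasing(surv_from(pred_pois, interval_len))  # TRUE
is_non_increasing(surv_pem(pred_pois, interval_len))   # TRUE
```

A negative prediction makes the cumulative sum drop, so the derived curve rises again; with strictly positive hazards both conversions stay monotone.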

Collaborator

Ah nice! We should document that in the example of the pipeline (and in the vignette). If you want to go the extra mile, you can create a small function that checks the learner in the pipeline for family/objective parameters and, if it doesn't find the keyword "poisson", throws a warning: "PEM works correctly with learners that support poisson regression". Andreas' list of candidate learners is enough (i.e. there are not that many either way).
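A rough sketch of such a check, under stated assumptions: `check_poisson` is a hypothetical name, and `values` is a plain named list standing in for `learner$param_set$values`. Only the `family` and `objective` keys are inspected, so learners that encode the distribution elsewhere would need extra handling.

```r
# Hypothetical helper, not part of the PR.
# `values`: named list of hyperparameter values (stand-in for learner$param_set$values).
# `keys`:   parameter names assumed to carry the distributional assumption.
check_poisson = function(values, keys = c("family", "objective")) {
  vals = unlist(values[intersect(names(values), keys)])
  found = length(vals) > 0 && any(grepl("poisson", vals, ignore.case = TRUE))
  if (!found) {
    warning("PEM works correctly with learners that support poisson regression")
  }
  invisible(found)
}

check_poisson(list(objective = "count:poisson", nrounds = 10)) # silent
check_poisson(list(family = "poisson"))                        # silent
# check_poisson(list(nrounds = 10))                            # would warn
```

Emitting a warning rather than an error keeps the pipeline usable with learners whose Poisson setting the check cannot see.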


# select the real tend values by only selecting the last row of each id
# basically a slightly more complex unique()
real_tend = data$obs_times[seq_len(nrow(data)) %% rows_per_id == 0]

ids = unique(data$id)
# select last row for every id => observed times
id = disc_status = NULL # to fix note
data = data[, .SD[.N, list(disc_status)], by = id]

# create prediction object
p = PredictionSurv$new(
row_ids = ids,
crank = pred_list$crank, distr = pred_list$distr,
truth = Surv(real_tend, as.integer(as.character(data$disc_status))))

list(p)
},

.train = function(input) {
self$state = list()
list(input)
}
)
)
register_pipeop("trafopred_regrsurv_PEM", PipeOpPredRegrSurvPEM)
76 changes: 76 additions & 0 deletions R/pipelines.R
@@ -659,6 +659,81 @@ pipeline_survtoclassif_disctime = function(learner, cut = NULL, max_time = NULL,
gr
}

#' @name mlr_graphs_survtoregr_PEM
#' @title Survival to Poisson Regression Reduction Pipeline
#' @description Wrapper around multiple [PipeOp][mlr3pipelines::PipeOp]s to help in creation
#' of complex survival reduction methods.
#'
#' @param learner [LearnerRegr][mlr3::LearnerRegr]\cr
#' Regression learner to fit the transformed [TaskRegr][mlr3::TaskRegr].
#' `learner` must be able to handle `offset`.
#' @param cut `numeric()`\cr
#' Split points, used to partition the data into intervals.
#' If unspecified, all unique event times will be used.
#' If `cut` is a single integer, it will be interpreted as the number of equidistant
#' intervals from 0 until the maximum event time.
#' @param max_time `numeric(1)`\cr
#' If cut is unspecified, this will be the last possible event time.
#' All event times after max_time will be administratively censored at max_time.
#' @param graph_learner `logical(1)`\cr
#' If `TRUE`, wraps the [Graph][mlr3pipelines::Graph] as a
#' [GraphLearner][mlr3pipelines::GraphLearner]; otherwise (default) returns a `Graph`.
#'
#' @details
#' The pipeline consists of the following steps:
#' \enumerate{
#' \item [PipeOpTaskSurvRegrPEM] Converts [TaskSurv] to a [TaskRegr][mlr3::TaskRegr].
#' \item A [LearnerRegr] is fit and predicted on the new `TaskRegr`.
#' \item [PipeOpPredRegrSurvPEM] transforms the resulting [PredictionRegr][mlr3::PredictionRegr]
#' to [PredictionSurv].
#' }
#'
#' @return [mlr3pipelines::Graph] or [mlr3pipelines::GraphLearner]
#' @family pipelines
#'
#' @examples
#' \dontrun{
#' if (requireNamespace("mlr3pipelines", quietly = TRUE) &&
#' requireNamespace("mlr3learners", quietly = TRUE)) {
#'
#' library(mlr3)
#' library(mlr3learners)
#' library(mlr3pipelines)
#'
#' task = tsk("lung")
#' part = partition(task)
#'
#' grlrn = ppl(
#' "survtoregr_PEM",
#' learner = lrn("regr.xgboost")
Collaborator

Better example: show the encoding of factors, maybe some modelmatrix trafo inside the learner? (consult Andreas)

#' )
#' grlrn$train(task, row_ids = part$train)
#' grlrn$predict(task, row_ids = part$test)
#' }
#' }
#' @export
pipeline_survtoregr_PEM = function(learner, cut = NULL, max_time = NULL,
rhs = NULL, graph_learner = FALSE) {
# TODO: add assertions

gr = mlr3pipelines::Graph$new()
Collaborator

You don't need the `mlr3pipelines::` prefix, as we import the package now; see the other pipelines (so remove it).

gr$add_pipeop(mlr3pipelines::po("trafotask_survregr_PEM", cut = cut, max_time = max_time))
gr$add_pipeop(mlr3pipelines::po("learner", learner))
gr$add_pipeop(mlr3pipelines::po("nop"))
gr$add_pipeop(mlr3pipelines::po("trafopred_regrsurv_PEM"))

gr$add_edge(src_id = "trafotask_survregr_PEM", dst_id = learner$id, src_channel = "output", dst_channel = "input")
gr$add_edge(src_id = "trafotask_survregr_PEM", dst_id = "nop", src_channel = "transformed_data", dst_channel = "input")
gr$add_edge(src_id = learner$id, dst_id = "trafopred_regrsurv_PEM", src_channel = "output", dst_channel = "input")
gr$add_edge(src_id = "nop", dst_id = "trafopred_regrsurv_PEM", src_channel = "output", dst_channel = "transformed_data")

if (graph_learner) {
gr = mlr3pipelines::GraphLearner$new(gr)
}

gr
}

register_graph("survaverager", pipeline_survaverager)
register_graph("survbagging", pipeline_survbagging)
register_graph("crankcompositor", pipeline_crankcompositor)
@@ -667,3 +742,4 @@ register_graph("responsecompositor", pipeline_responsecompositor)
register_graph("probregr", pipeline_probregr)
register_graph("survtoregr", pipeline_survtoregr)
register_graph("survtoclassif_disctime", pipeline_survtoclassif_disctime)
register_graph("survtoregr_PEM", pipeline_survtoregr_PEM)
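The `cut` and `max_time` semantics documented for `pipeline_survtoregr_PEM` can be sketched in base R. This is one possible reading of the parameter documentation, not the code of `PipeOpTaskSurvRegrPEM`; `make_cut_points` and its arguments are hypothetical, and the data are made up.

```r
# Hypothetical helper illustrating the documented `cut` / `max_time` behaviour.
# time:   observed times; status: 1 = event, 0 = censored.
make_cut_points = function(time, status, cut = NULL, max_time = NULL) {
  event_times = time[status == 1]
  if (is.null(cut)) {
    # Unspecified: use all unique event times, optionally capped at max_time.
    cut = sort(unique(event_times))
    if (!is.null(max_time)) cut = cut[cut <= max_time]
  } else if (length(cut) == 1L) {
    # Single integer: that many equidistant intervals from 0 to the last event time.
    cut = seq(0, max(event_times), length.out = cut + 1)
  }
  cut
}

time   = c(2, 5, 5, 8, 10)
status = c(1, 1, 0, 1, 0)

make_cut_points(time, status)               # 2 5 8
make_cut_points(time, status, cut = 4)      # 0 2 4 6 8
make_cut_points(time, status, max_time = 6) # 2 5
```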