PEM Pipeline #417

Open
wants to merge 27 commits into
base: main
6eb5a6e
draft task conversion pipeop
studener Sep 24, 2024
4d49fa3
draft pipeop + pipeline
studener Sep 26, 2024
d786d08
update pred conversion pipeop
studener Sep 26, 2024
b7bc0c6
added modelmatrix pipeop to PEM pipeline, changed variable naming to …
markusgoeswein Oct 29, 2024
5d6b61b
added col_role original_ids to regression tasks
markusgoeswein Nov 14, 2024
c19e4bb
changed id column role to original_ids
markusgoeswein Nov 14, 2024
717478d
added additional arguments to TaskSurvRegrPEM to enable more complex …
markusgoeswein Nov 21, 2024
976d9c0
form is now to be passed without quotation marks
markusgoeswein Dec 6, 2024
35745f0
resolve merge conflict with main, before merging
markusgoeswein Jan 31, 2025
e2a6c21
resolve merge conflict in R\piplines.R
markusgoeswein Jan 31, 2025
5a75617
setting up unit tests for the PEM pipeline
markusgoeswein Feb 21, 2025
a3bdbf5
update function doc
markusgoeswein Feb 25, 2025
6bf7332
update remotes for offset support
markusgoeswein Feb 25, 2025
690b5c6
remove ped_formula argument in favour of automatically parsing it
markusgoeswein Mar 3, 2025
f341679
add assert and set use_pred_offset to FALSE if not done so
markusgoeswein Mar 3, 2025
d684cd0
Regenerate Rd files using devtools::document()
markusgoeswein Mar 3, 2025
b14e678
adjust assertions in pipeline
markusgoeswein Mar 4, 2025
f00f1f9
Merge pull request #436 from mlr-org/main
markusgoeswein Mar 4, 2025
7a848bf
setting up test_PEM.R and adjustments to tests in pipelines
markusgoeswein Mar 4, 2025
2fc11a4
change lrn() from regr.xgboost to regr.glmnet
markusgoeswein Mar 13, 2025
c035e9b
update DESCRIPTION
markusgoeswein Mar 13, 2025
cb52f0b
add glmnet to suggests for PEM pipeline tests
markusgoeswein Mar 13, 2025
4e509e4
included require_namespace('glmnet, ...) for PEM pipeline tests
markusgoeswein Mar 13, 2025
82b8af7
minor fix
markusgoeswein Mar 13, 2025
0532d71
set PEM to lowercase in PipeOpTask and PipeOpPred, update DESCRIPTION…
markusgoeswein Mar 25, 2025
ec0ffdd
temporary name change of PEM pipeops
markusgoeswein Mar 25, 2025
a5402c3
man files are renamed with lowercase pem
markusgoeswein Mar 25, 2025
update function doc
markusgoeswein committed Feb 25, 2025
commit a3bdbf508952a1d96d48b35902c3b2819fb807d7
26 changes: 20 additions & 6 deletions R/PipeOpPredRegrSurvPEM.R
@@ -2,8 +2,22 @@
#' @name mlr_pipeops_trafopred_regrsurv_PEM
#'
#' @description
#' Transform [PredictionRegr] to [PredictionSurv].
#'
#' Transform [PredictionRegr] to [PredictionSurv].
#' Predicted hazards are transformed into survival probabilities and wrapped in a
#' [PredictionSurv] object.
#'
#' Continuous time is partitioned into time intervals \eqn{[0, t_1), [t_1, t_2), ..., [t_J, \infty)}.
#' [PredictionRegr] contains the estimates of the piece-wise constant hazards defined as
#' \deqn{\lambda(t \mid \mathbf{x}_i (t)) := \exp(g(x_{ij}, t_{j})), \quad \forall t \in [t_{j-1}, t_{j}), \quad i = 1, \dots, n.}
#'
#' Via the following identity
#' \deqn{S(t \mid \mathbf{x}) = \exp \left( - \int_{0}^{t} \lambda(s \mid \mathbf{x}) \, ds \right) = \exp \left( - \sum_{j = 1}^{J} \lambda(j \mid \mathbf{x}) \, d_j \right),}
#' where \eqn{d_j} specifies the duration of interval \eqn{j},
#'
#' we compute the survival probability from the predicted hazards.
Collaborator

Excellent! I would suggest: 1) remove the time dependency, as we don't support it, i.e. use x instead of x(t); 2) describe the g function a bit; 3) add a reference via the bibtex file => Andreas 2018 paper (A generalized additive model approach to time-to-event analysis)

Collaborator

I added some additional descriptions for clarification. At the same time, I am wondering whether PipeOpPred is the appropriate place for these mathematical explanations. I feel that, as a whole, they might be more appropriate in the doc of the pipeline, with the exception of the back-transform portion, i.e. how we get survival probabilities from hazards.

Collaborator

I see your point - it makes sense to put the related doc directly in the class that implements what the docs say, and provide links to it in the pipeline. But it's up to you if you want to add extra math doc somewhere (even a bit duplicated), always welcome!
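
To make the back-transformation discussed in this thread concrete, here is a minimal base R sketch of turning piece-wise constant hazards into survival probabilities at the interval end points. The helper name `hazards_to_surv` and the toy numbers are illustrative assumptions, not part of the PR. It mirrors the `exp(-cumsum(hazard * exp(offset)))` step in `.predict`, under the assumption that the offset equals the log of the interval duration.

```r
# Hypothetical helper (not in the PR): survival probabilities at the
# interval end points, given one piece-wise constant hazard per interval.
hazards_to_surv = function(hazard, tend) {
  stopifnot(length(hazard) == length(tend), all(hazard >= 0), !is.unsorted(tend))
  d = diff(c(0, tend))       # interval durations d_j, first interval starts at 0
  exp(-cumsum(hazard * d))   # S(t_j) = exp(-sum_{k <= j} lambda_k * d_k)
}

tend = c(1, 2, 4)            # end points of [0, 1), [1, 2), [2, 4)
hazard = c(0.1, 0.2, 0.05)   # toy hazards, not real predictions
surv = hazards_to_surv(hazard, tend)
```

The cumulative sum is what makes the result a survival curve per observation: each entry accumulates the hazard mass of all intervals up to and including its own, so the values can only decrease.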

#'
#'
#'
#' @section Dictionary:
#' This [PipeOp][mlr3pipelines::PipeOp] can be instantiated via the
#' [dictionary][mlr3misc::Dictionary] [mlr3pipelines::mlr_pipeops]
@@ -51,20 +65,20 @@ PipeOpPredRegrSurvPEM = R6Class(

private = list(
.predict = function(input) {
pred = input[[1]]
pred = input[[1]] # retrieve the hazards predicted by the regression learner
data = input[[2]]
assert_true(!is.null(pred$response))
# probability of having the event (1) in each respective interval
# is the discrete-time hazard


data = cbind(data, dt_hazard = pred$response)

# From theory, convert hazards to surv as exp(-cumsum(h(t) * exp(offset)))
rows_per_id = nrow(data) / length(unique(data$id))

# If 'single_event', 'cr', 'msm')
surv = t(vapply(unique(data$id), function(unique_id) {
exp(-cumsum(data[data$id == unique_id, ][["dt_hazard"]] * exp(data[data$id == unique_id, ][["offset"]])))
}, numeric(rows_per_id)))


unique_end_times = sort(unique(data$tend))
# coerce to distribution and crank
53 changes: 30 additions & 23 deletions R/PipeOpTaskSurvRegrPEM.R
@@ -4,7 +4,9 @@
#'
#' @description
#' Transform [TaskSurv] to [TaskRegr][mlr3::TaskRegr] by dividing continuous
#' time into multiple time intervals for each observation.
#' time into multiple time intervals for each observation. The survival data set
#' stored in [TaskSurv] is transformed into Piece-wise Exponential Data (PED) format
#' which in turn forms the backend for [TaskRegr][mlr3::TaskRegr].
#' This transformation creates a new target variable `PEM_status` that indicates
#' whether an event occurred within each time interval.
#'
@@ -31,8 +33,8 @@
#' The "transformed_data" is an empty [data.table][data.table::data.table].
#'
#' During prediction, the "input" [TaskSurv] is transformed to the "output"
#' [TaskRegr][mlr3::TaskRegr] with `"PEM_status"` as target and the `"tend"`
#' as well as `"offset"` feature included.
#' [TaskRegr][mlr3::TaskRegr] with `"PEM_status"` as target, while `"tend"`
#' and `"offset"` are included as features.
Collaborator

More accurately: offset is not a feature, i.e. it doesn't have the col_role feature, but the offset one

#' The "transformed_data" is a [data.table] with columns the `"PEM_status"`
#' target of the "output" task, the `"id"` (original observation ids),
#' `"obs_times"` (observed times per `"id"`) and `"tend"` (end time of each interval).
@@ -53,12 +55,10 @@
#' If `cut` is unspecified, this will be the last possible event time.
#' All event times after `max_time` will be administratively censored at `max_time`.
#' Needs to be greater than the minimum event time in the given task.
#'
#' @examples
#'
#' @examplesIf mlr3misc::require_namespaces(c("mlr3pipelines", "mlr3learners"), quietly = TRUE)
#' * `ped_formula`
#' TODO
#' @examplesIf (mlr3misc::require_namespaces(c("mlr3pipelines", "mlr3extralearners"), quietly = TRUE))
#' \dontrun{
#' # Update documentation to match PEM
#' library(mlr3)
#' library(mlr3learners)
#' library(mlr3pipelines)
@@ -73,11 +73,10 @@
#' # the end time points of the discrete time intervals
#' unique(task_regr$data(cols = "tend"))[[1L]]
#'
#' # train a classification learner
#' learner = lrn("classif.log_reg", predict_type = "prob")
#' # train a regression learner
Collaborator

... that supports poisson regression....

#' learner = lrn("regr.gam") # won't run unless learner can accept offset column role
Collaborator

TODO: when I finish the mlr3extralearners PR, we can safely remove this comment here.

Also correct the example + make it a bit more interesting, e.g. => l = lrn("regr.gam", formula = pem_status ~ s(age) + s(tend), family = "poisson") => you definitely need the family poisson argument here
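
As a rough illustration of why the Poisson family matters here, the sketch below fits a piece-wise exponential model with base R `glm()` instead of `lrn("regr.gam")`, since offset support on the mlr3 side is still pending according to this thread. Everything in it is an assumption for illustration: the data are simulated, the covariate `x` is hypothetical, and the manual expansion only stands in for what `pammtools::as_ped()` does in the PR.

```r
set.seed(1)
n = 500
x = rnorm(n)                                # hypothetical covariate
time = rexp(n, rate = exp(-1 + 0.5 * x))    # assumed true hazard: exp(-1 + 0.5 * x)
cut = c(0, 1, 2, 3)                         # interval borders
max_time = max(cut)
status = as.integer(time <= max_time)       # administrative censoring at max_time
time = pmin(time, max_time)
tstart = cut[-length(cut)]
tend = cut[-1]

# One row per subject and interval in which the subject was still at risk.
ped = do.call(rbind, lapply(seq_len(n), function(i) {
  keep = tstart < time[i]
  data.frame(
    id = i,
    x = x[i],
    tend = tend[keep],
    offset = log(pmin(time[i], tend[keep]) - tstart[keep]),  # log time at risk
    ped_status = as.integer(status[i] == 1 & time[i] <= tend[keep])
  )
}))

# Poisson regression with a log time-at-risk offset: the linear predictor
# then models the log hazard, one baseline level per interval.
fit = glm(ped_status ~ factor(tend) + x + offset(offset),
  family = poisson(), data = ped)
coef(fit)
```

With a Gaussian family the same call would model the 0/1 status directly and ignore the time-at-risk scaling, which is why the learner in the roxygen example needs the Poisson argument and the offset column role.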

#' learner$train(task_regr)
#' }
#' }
#' }
#'
#'
#' @family PipeOps
@@ -95,7 +94,7 @@ PipeOpTaskSurvRegrPEM = R6Class("PipeOpTaskSurvRegrPEM",
max_time = p_dbl(0, default = NULL, special_vals = list(NULL)),
censor_code = p_int(0L),
min_events = p_int(1L),
form = p_uty(tags = 'train')
ped_formula = p_uty(tags = 'train', default = NULL)
#pammtools arguments: transitions etc.
)
super$initialize(
@@ -139,16 +138,16 @@ PipeOpTaskSurvRegrPEM = R6Class("PipeOpTaskSurvRegrPEM",
"max_time must be greater than the minimum event time.")
}

Collaborator

removing redundant empty lines in all code would be nice - some space is good, more space is unnecessary

# To-Do: Extend to a more general formulation for competing risks and msm
# Issue: We pass form (e.g. Surv(time, status) ~ .) which currently serves to correctly transform the data into ped format
# but doesn't serve any other purpose yet. For ML learners, such as xgb, the covariate structure is passed to the pipeline via rhs not form.
long_data = pammtools::as_ped(data = data, formula = self$param_set$values$form, cut = cut, max_time = max_time)
self$state$cut = attributes(long_data)$trafo_args$cut
ped_formula = self$param_set$values$ped_formula
if (is.null(ped_formula)){
ped_formula = formulate(sprintf("Surv(%s, %s)", time_var, event_var), ".")
}
long_data = pammtools::as_ped(data = data, formula = ped_formula, cut = cut, max_time = max_time)
long_data = as.data.table(long_data)


self$state$cut = attributes(long_data)$trafo_args$cut

long_data = as.data.table(long_data)
setnames(long_data, old = "ped_status", new = "PEM_status") #change to PEM
setnames(long_data, old = "ped_status", new = "PEM_status")

# remove some columns from `long_data`
long_data[, c("tstart", "interval") := NULL]
@@ -161,6 +160,7 @@ PipeOpTaskSurvRegrPEM = R6Class("PipeOpTaskSurvRegrPEM",
task_PEM = TaskRegr$new(paste0(task$id, "_PEM"), long_data,
target = "PEM_status")
Collaborator

A slightly more proper indentation style: target should be aligned below new( <= here. Please check all code for this.

task_PEM$set_col_roles("id", roles = "original_ids")
task_PEM$set_col_roles('offset', roles = "offset")
Collaborator

style: no ' anywhere in the code please, use only "!


list(task_PEM, data.table())
},
@@ -181,9 +181,15 @@ PipeOpTaskSurvRegrPEM = R6Class("PipeOpTaskSurvRegrPEM",

status = data[[event_var]]
data[[event_var]] = 1


Collaborator

instead of extra space: a good informative comment!

Collaborator
markusgoeswein Mar 26, 2025

Frankly, not quite sure what's the purpose of data[[event_var]] = 1
As for data[[time_var]] = max_time, this ensures that for each subject the ped data spans over all intervals instead of only until the event time, which of course is sensible for prediction. I added this as a comment.
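
The point about `max_time` can be sketched in base R: at prediction time every subject should get one row per interval, regardless of when its event happened. `expand_for_prediction` is a hypothetical helper written for this illustration only; the PR delegates the expansion to `pammtools::as_ped()`, and the offset is assumed here to be the log interval duration.

```r
# Hypothetical helper (not in the PR): one row per subject and interval,
# as if every subject were observed until the last cut point.
expand_for_prediction = function(ids, cut) {
  tstart = cut[-length(cut)]
  tend = cut[-1]
  # expand.grid varies its first argument fastest, so rows are grouped by id
  grid = expand.grid(interval = seq_along(tend), id = ids)
  data.frame(
    id = grid$id,
    tend = tend[grid$interval],
    offset = log(tend - tstart)[grid$interval]
  )
}

newdata = expand_for_prediction(ids = 1:2, cut = c(0, 1, 2, 4))
table(newdata$id)   # every id has the same number of rows
```

Having the same number of rows per id is also what the `rows_per_id = nrow(data) / length(unique(data$id))` computation in `PipeOpPredRegrSurvPEM` relies on.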

ped_formula = self$param_set$values$ped_formula
if (is.null(ped_formula)){
ped_formula = formulate(sprintf("Surv(%s, %s)", time_var, event_var), ".")
}
long_data = pammtools::as_ped(data = data, formula = ped_formula, cut = cut, max_time = max_time)
long_data = as.data.table(long_data)


long_data = as.data.table(pammtools::as_ped(data, formula = self$param_set$values$form, cut = cut))
setnames(long_data, old = "ped_status", new = "PEM_status")

PEM_status = id = tend = obs_times = NULL # fixing global binding notes of data.table
@@ -204,6 +210,7 @@ PipeOpTaskSurvRegrPEM = R6Class("PipeOpTaskSurvRegrPEM",
task_PEM = TaskRegr$new(paste0(task$id, "_PEM"), long_data,
target = "PEM_status")
task_PEM$set_col_roles("id", roles = "original_ids")
task_PEM$set_col_roles('offset', roles = "offset")

# map observed times back
reps = table(long_data$id)