Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

908: Flexible model column #915

Merged
merged 21 commits into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 0 additions & 29 deletions R/check-input-helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,35 +40,6 @@ check_try <- function(expr) {
return(msg)
}


#' @title Ensure that data has a `model` column
#'
#' @description
#' Check whether the data.table has a column called `model`.
#' If not, a warning is issued and a column called `model` is added with the
#' value `Unspecified model`. The column is added with data.table's `:=`,
#' i.e. the input is modified by reference rather than copied.
#' @inheritParams as_forecast
#' @importFrom cli cli_warn
#' @importFrom checkmate assert_data_table
#' @return The data.table with a column called `model`
#' @keywords internal_input_check
ensure_model_column <- function(data) {
  assert_data_table(data)
  if (!("model" %in% colnames(data))) {
    #nolint start: keyword_quote_linter
    cli_warn(
      c(
        "!" = "There is no column called `model` in the data.",
        "i" = "scoringutils assumes that all forecasts come from the
        same model"
      )
    )
    #nolint end
    data[, model := "Unspecified model"]
  }
  # NOTE(review): the trailing `[]` presumably ensures the returned
  # data.table auto-prints after the `:=` above -- confirm against the
  # data.table documentation.
  return(data[])
}


#' Check that all forecasts have the same number of quantiles or samples
#' @description
#' Function checks the number of quantiles or samples per forecast.
Expand Down
40 changes: 10 additions & 30 deletions R/forecast.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
#'
#' The `as_forecast_<type>()` functions give users some control over how their
#' data is parsed.
#' Using the arguments `observed`, `predicted`, `model`, etc. users can rename
#' Using the arguments `observed`, `predicted`, etc. users can rename
#' existing columns of their input data to match the required columns for a
#' forecast object. Using the argument `forecast_unit`, users can specify the
#' columns that uniquely identify a single forecast (and remove the others,
Expand All @@ -37,9 +37,6 @@
#' observed values. This column will be renamed to "observed".
#' @param predicted (optional) Name of the column in `data` that contains the
#' predicted values. This column will be renamed to "predicted".
#' @param model (optional) Name of the column in `data` that contains the names
#' of the models/forecasters that generated the predicted values.
#' This column will be renamed to "model".
#' @inheritSection forecast_types Forecast types and input formats
#' @inheritSection forecast_types Forecast unit
#' @return
Expand Down Expand Up @@ -72,8 +69,7 @@ NULL
as_forecast_generic <- function(data,
forecast_unit = NULL,
observed = NULL,
predicted = NULL,
model = NULL) {
predicted = NULL) {
# check inputs - general
data <- ensure_data.table(data)
assert_character(observed, len = 1, null.ok = TRUE)
Expand All @@ -82,22 +78,13 @@ as_forecast_generic <- function(data,
assert_character(predicted, len = 1, null.ok = TRUE)
assert_subset(predicted, names(data), empty.ok = TRUE)

assert_character(model, len = 1, null.ok = TRUE)
assert_subset(model, names(data), empty.ok = TRUE)

# rename columns - general
if (!is.null(observed)) {
setnames(data, old = observed, new = "observed")
}
if (!is.null(predicted)) {
setnames(data, old = predicted, new = "predicted")
}
if (!is.null(model)) {
setnames(data, old = model, new = "model")
}

# ensure that a model column is present after renaming
ensure_model_column(data)

# set forecast unit (error handling is done in `set_forecast_unit()`)
if (!is.null(forecast_unit)) {
Expand All @@ -119,9 +106,8 @@ as_forecast_generic <- function(data,
as_forecast_binary <- function(data,
forecast_unit = NULL,
observed = NULL,
predicted = NULL,
model = NULL) {
data <- as_forecast_generic(data, forecast_unit, observed, predicted, model)
predicted = NULL) {
data <- as_forecast_generic(data, forecast_unit, observed, predicted)
data <- new_forecast(data, "forecast_binary")
assert_forecast(data)
return(data)
Expand Down Expand Up @@ -149,9 +135,8 @@ as_forecast_point.default <- function(data,
forecast_unit = NULL,
observed = NULL,
predicted = NULL,
model = NULL,
...) {
data <- as_forecast_generic(data, forecast_unit, observed, predicted, model)
data <- as_forecast_generic(data, forecast_unit, observed, predicted)
data <- new_forecast(data, "forecast_point")
assert_forecast(data)
return(data)
Expand Down Expand Up @@ -206,7 +191,6 @@ as_forecast_quantile.default <- function(data,
forecast_unit = NULL,
observed = NULL,
predicted = NULL,
model = NULL,
quantile_level = NULL,
...) {
assert_character(quantile_level, len = 1, null.ok = TRUE)
Expand All @@ -215,7 +199,7 @@ as_forecast_quantile.default <- function(data,
setnames(data, old = quantile_level, new = "quantile_level")
}

data <- as_forecast_generic(data, forecast_unit, observed, predicted, model)
data <- as_forecast_generic(data, forecast_unit, observed, predicted)
data <- new_forecast(data, "forecast_quantile")
assert_forecast(data)
return(data)
Expand Down Expand Up @@ -280,15 +264,14 @@ as_forecast_sample <- function(data,
forecast_unit = NULL,
observed = NULL,
predicted = NULL,
model = NULL,
sample_id = NULL) {
assert_character(sample_id, len = 1, null.ok = TRUE)
assert_subset(sample_id, names(data), empty.ok = TRUE)
if (!is.null(sample_id)) {
setnames(data, old = sample_id, new = "sample_id")
}

data <- as_forecast_generic(data, forecast_unit, observed, predicted, model)
data <- as_forecast_generic(data, forecast_unit, observed, predicted)
data <- new_forecast(data, "forecast_sample")
assert_forecast(data)
return(data)
Expand All @@ -312,15 +295,14 @@ as_forecast_nominal <- function(data,
forecast_unit = NULL,
observed = NULL,
predicted = NULL,
model = NULL,
predicted_label = NULL) {
assert_character(predicted_label, len = 1, null.ok = TRUE)
assert_subset(predicted_label, names(data), empty.ok = TRUE)
if (!is.null(predicted_label)) {
setnames(data, old = predicted_label, new = "predicted_label")
}

data <- as_forecast_generic(data, forecast_unit, observed, predicted, model)
data <- as_forecast_generic(data, forecast_unit, observed, predicted)
data <- new_forecast(data, "forecast_nominal")
assert_forecast(data)
return(data)
Expand Down Expand Up @@ -523,7 +505,7 @@ validate_forecast <- function(forecast, forecast_type = NULL, verbose = TRUE) {
#' The function runs input checks that apply to all input data, regardless of
#' forecast type. The function
#' - asserts that the forecast is a data.table which has columns `observed` and
#' `predicted`, as well as a column called `model`.
#' `predicted`
#' - checks the forecast type and forecast unit
#' - checks there are no duplicate forecasts
#' - if appropriate, checks the number of samples / quantiles is the same
Expand All @@ -539,7 +521,7 @@ validate_forecast <- function(forecast, forecast_type = NULL, verbose = TRUE) {
assert_forecast_generic <- function(data, verbose = TRUE) {
# check that data is a data.table and that the columns look fine
assert_data_table(data, min.rows = 1)
assert(check_columns_present(data, c("observed", "predicted", "model")))
assert(check_columns_present(data, c("observed", "predicted")))
problem <- test_columns_present(data, c("sample_id", "quantile_level"))
if (problem) {
cli_abort(
Expand Down Expand Up @@ -616,7 +598,6 @@ clean_forecast <- function(forecast, copy = FALSE, na.omit = FALSE) {
#' @description
#' Construct a class based on a data.frame or similar. The constructor
#' - coerces the data into a data.table
#' - makes sure that a column called `model` exists and if not creates one
#' - assigns a class
#'
#' @inheritParams as_forecast
Expand All @@ -626,7 +607,6 @@ clean_forecast <- function(forecast, copy = FALSE, na.omit = FALSE) {
#' @keywords internal
new_forecast <- function(data, classname) {
data <- as.data.table(data)
data <- ensure_model_column(data)
class(data) <- c(classname, "forecast", class(data))
data <- copy(data)
return(data[])
Expand Down
Loading
Loading