diff --git a/R/validate.R b/R/validate.R index 13f84af74..dde7a5bbf 100644 --- a/R/validate.R +++ b/R/validate.R @@ -18,16 +18,84 @@ #' @keywords check-forecasts #' @examples #' as_forecast(example_binary) -#' as_forecast(example_quantile) -as_forecast <- function(data, ...) { +#' as_forecast( +#' example_quantile, +#' forecast_unit = c("model", "target_type", "target_end_date", +#' "horizon", "location") +#' ) +as_forecast <- function(data, + ...) { UseMethod("as_forecast") } #' @rdname as_forecast +#' @param observed (optional) Name of the column in `data` that contains the +#' observed values. This column will be renamed to "observed". +#' @param predicted (optional) Name of the column in `data` that contains the +#' predicted values. This column will be renamed to "predicted". +#' @param model (optional) Name of the column in `data` that contains the names +#' of the models/forecasters that generated the predicted values. +#' This column will be renamed to "model". +#' @param forecast_unit (optional) Name of the columns in `data` (after +#' renaming) that denote the unit of a single forecast. +#' See [get_forecast_unit()] for details. +#' @param quantile_level (optional) Name of the column in `data` that contains +#' the quantile level of the predicted values. This column will be renamed to +#' "quantile_level". Only applicable to quantile-based forecasts. +#' @param sample_id (optional) Name of the column in `data` that contains the +#' sample id. This column will be renamed to "sample_id". Only applicable to +#' sample-based forecasts. #' @export -as_forecast.default <- function(data, ...) { +as_forecast.default <- function(data, + observed = NULL, + predicted = NULL, + model = NULL, + forecast_unit = NULL, + quantile_level = NULL, + sample_id = NULL, + ...) { + # check inputs + data <- ensure_data.table(data) + assert_character(observed, len = 1, null.ok = TRUE) + assert_subset(observed, names(data), empty.ok = TRUE) + + assert_character(predicted, len = 1, null.ok = TRUE) + assert_subset(predicted, names(data), empty.ok = TRUE) + + assert_character(model, len = 1, null.ok = TRUE) + assert_subset(model, names(data), empty.ok = TRUE) + + assert_character(quantile_level, len = 1, null.ok = TRUE) + assert_subset(quantile_level, names(data), empty.ok = TRUE) + + assert_character(sample_id, len = 1, null.ok = TRUE) + assert_subset(sample_id, names(data), empty.ok = TRUE) + + # rename columns + if (!is.null(observed)) { + setnames(data, old = observed, new = "observed") + } + if (!is.null(predicted)) { + setnames(data, old = predicted, new = "predicted") + } + if (!is.null(model)) { + setnames(data, old = model, new = "model") + } + if (!is.null(quantile_level)) { + setnames(data, old = quantile_level, new = "quantile_level") + } + if (!is.null(sample_id)) { + setnames(data, old = sample_id, new = "sample_id") + } + + # assert that everything worked out assert(check_data_columns(data)) + # set forecast unit (error handling is done in `set_forecast_unit()`) + if (!is.null(forecast_unit)) { + data <- set_forecast_unit(data, forecast_unit) + } + # find forecast type forecast_type <- get_forecast_type(data)