Skip to content

Commit

Permalink
Merge pull request #390 from epiforecasts/rework-add_coverage()
Browse files Browse the repository at this point in the history
Rework add coverage()
  • Loading branch information
nikosbosse authored Nov 16, 2023
2 parents ff54435 + 8fe6e2e commit eb45cbb
Show file tree
Hide file tree
Showing 86 changed files with 2,451 additions and 1,758 deletions.
10 changes: 10 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,12 @@ export(bias_sample)
export(brier_score)
export(correlation)
export(crps_sample)
export(dispersion)
export(dss_sample)
export(get_duplicate_forecasts)
export(interval_coverage_deviation_quantile)
export(interval_coverage_quantile)
export(interval_coverage_sample)
export(interval_score)
export(log_shift)
export(logs_binary)
Expand All @@ -39,6 +43,7 @@ export(make_NA)
export(make_na)
export(merge_pred_and_obs)
export(new_scoringutils)
export(overprediction)
export(pairwise_comparison)
export(pit)
export(pit_sample)
Expand All @@ -54,6 +59,7 @@ export(plot_ranges)
export(plot_score_table)
export(plot_wis)
export(quantile_score)
export(run_safely)
export(sample_to_quantile)
export(score)
export(se_mean_sample)
Expand All @@ -63,6 +69,7 @@ export(summarise_scores)
export(summarize_scores)
export(theme_scoringutils)
export(transform_forecasts)
export(underprediction)
export(validate)
export(validate_general)
export(wis)
Expand All @@ -75,6 +82,8 @@ importFrom(checkmate,assert_data_frame)
importFrom(checkmate,assert_data_table)
importFrom(checkmate,assert_factor)
importFrom(checkmate,assert_list)
importFrom(checkmate,assert_logical)
importFrom(checkmate,assert_number)
importFrom(checkmate,assert_numeric)
importFrom(checkmate,check_atomic_vector)
importFrom(checkmate,check_data_frame)
Expand All @@ -100,6 +109,7 @@ importFrom(data.table,nafill)
importFrom(data.table,rbindlist)
importFrom(data.table,setDT)
importFrom(data.table,setattr)
importFrom(data.table,setcolorder)
importFrom(data.table,setnames)
importFrom(ggdist,geom_lineribbon)
importFrom(ggplot2,.data)
Expand Down
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ The update introduces a lot of breaking changes. If you want to keep using the o
- `quantile`: numeric, a vector with quantile-levels. Can alternatively be a matrix of the same shape as `predicted`.
- `check_forecasts()` was replaced by a new function `validate()`. `validate()` validates the input and in that sense fulfills the purpose of `check_forecasts()`. It has different methods: `validate.default()` assigns the input a class based on their forecast type. Other methods validate the input specifically for the various forecast types.
- The functionality for computing pairwise comparisons was now split from `summarise_scores()`. Instead of doing pairwise comparisons as part of summarising scores, a new function, `add_pairwise_comparison()`, was introduced that takes summarised scores as an input and adds pairwise comparisons to it.
- `add_coverage()` was reworked completely. It's new purpose is now to add coverage information to the raw forecast data (essentially fulfilling some of the functionality that was previously covered by `score_quantile()`)
- The function `find_duplicates()` was renamed to `get_duplicate_forecasts()`
- Changes to `avail_forecasts()` and `plot_avail_forecasts()`:
- The function `avail_forecasts()` was renamed to `available_forecasts()` for consistency with `available_metrics()`. The old function, `avail_forecasts()` is still available as an alias, but will be removed in the future.
Expand Down
83 changes: 83 additions & 0 deletions R/add_coverage.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
#' @title Add Coverage Values to Quantile-Based Forecasts
#'
#' @description Adds interval coverage of central prediction intervals,
#' quantile coverage for predictive quantiles, as well as the deviation between
#' desired and actual coverage to a data.table. Forecasts should be in a
#' quantile format (following the input requirements of `score()`).
#'
#' **Interval coverage**
#'
#' Coverage for a given interval range is defined as the proportion of
#' observations that fall within the corresponding central prediction intervals.
#' Central prediction intervals are symmetric around the median and and formed
#' by two quantiles that denote the lower and upper bound. For example, the 50%
#' central prediction interval is the interval between the 0.25 and 0.75
#' quantiles of the predictive distribution.
#'
#' The function `add_coverage()` computes the coverage per central prediction
#' interval, so the coverage will always be either `TRUE` (observed value falls
#' within the interval) or `FALSE` (observed value falls outside the interval).
#' You can summarise the coverage values to get the proportion of observations
#' that fall within the central prediction intervals.
#'
#' **Quantile coverage**
#'
#' Quantile coverage for a given quantile is defined as the proportion of
#' observed values that are smaller than the corresponding predictive quantile.
#' For example, the 0.5 quantile coverage is the proportion of observed values
#' that are smaller than the 0.5 quantile of the predictive distribution.
#'
#' **Coverage deviation**
#'
#' The coverage deviation is the difference between the desired coverage and the
#' actual coverage. For example, if the desired coverage is 90% and the actual
#' coverage is 80%, the coverage deviation is -0.1.
#'
#' @inheritParams score
#' @return a data.table with the input and columns "interval_coverage",
#' "interval_coverage_deviation", "quantile_coverage",
#' "quantile_coverage_deviation" added.
#' @importFrom data.table setcolorder
#' @examples
#' library(magrittr) # pipe operator
#' example_quantile %>%
#' add_coverage()
#' @export
#' @keywords scoring
#' @export
add_coverage <- function(data) {
stored_attributes <- get_scoringutils_attributes(data)
data <- validate(data)
forecast_unit <- get_forecast_unit(data)
data_cols <- colnames(data) # store so we can reset column order later

# what happens if quantiles are not symmetric around the median?
# should things error? Also write tests for that.
interval_data <- quantile_to_interval(data, format = "wide")
interval_data[, interval_coverage := ifelse(
observed <= upper & observed >= lower,
TRUE,
FALSE)
][, c("lower", "upper", "observed") := NULL]

data[, range := get_range_from_quantile(quantile)]

data <- merge(interval_data, data, by = unique(c(forecast_unit, "range")))
data[, interval_coverage_deviation := interval_coverage - range / 100]
data[, quantile_coverage := observed <= predicted]
data[, quantile_coverage_deviation := quantile_coverage - quantile]

# reset column order
new_metrics <- c("interval_coverage", "interval_coverage_deviation",
"quantile_coverage", "quantile_coverage_deviation")
setcolorder(data, unique(c(data_cols, "range", new_metrics)))

# add coverage "metrics" to list of stored metrics
# this makes it possible to use `summarise_scores()` later on
stored_attributes[["metric_names"]] <- c(
stored_attributes[["metric_names"]],
new_metrics
)
data <- assign_attributes(data, stored_attributes)
return(data[])
}
14 changes: 12 additions & 2 deletions R/check-input-helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -297,12 +297,22 @@ check_columns_present <- function(data, columns) {
}
assert_character(columns, min.len = 1)
colnames <- colnames(data)
missing <- list()
for (x in columns){
if (!(x %in% colnames)) {
msg <- paste0("Column '", x, "' not found in data")
return(msg)
missing[[x]] <- x
}
}
missing <- unlist(missing)
if (length(missing) > 1) {
msg <- paste0(
"Columns '", paste(missing, collapse = "', '"), "' not found in data"
)
return(msg)
} else if (length(missing) == 1) {
msg <- paste0("Column '", missing, "' not found in data")
return(msg)
}
return(TRUE)
}

Expand Down
18 changes: 5 additions & 13 deletions R/convenience-functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -235,21 +235,13 @@ log_shift <- function(x, offset = 0, base = exp(1)) {
#' example_quantile,
#' c("location", "target_end_date", "target_type", "horizon", "model")
#' )

set_forecast_unit <- function(data, forecast_unit) {

datacols <- colnames(data)
missing <- forecast_unit[!(forecast_unit %in% datacols)]

if (length(missing) > 0) {
warning(
"Column(s) '",
missing,
"' are not columns of the data and will be ignored."
)
forecast_unit <- intersect(forecast_unit, datacols)
data <- ensure_data.table(data)
missing <- check_columns_present(data, forecast_unit)
if (!is.logical(missing)) {
warning(missing)
forecast_unit <- intersect(forecast_unit, colnames(data))
}

keep_cols <- c(get_protected_columns(data), forecast_unit)
out <- unique(data[, .SD, .SDcols = keep_cols])[]
return(out)
Expand Down
27 changes: 21 additions & 6 deletions R/data.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
#' \item{model}{name of the model that generated the forecasts}
#' \item{horizon}{forecast horizon in weeks}
#' }
#' @source \url{https://github.com/covid19-forecast-hub-europe/covid19-forecast-hub-europe/commit/a42867b1ea152c57e25b04f9faa26cfd4bfd8fa6/} # nolint
#' @source \url{https://github.com/covid19-forecast-hub-europe/covid19-forecast-hub-europe/commit/a42867b1ea152c57e25b04f9faa26cfd4bfd8fa6/}
"example_quantile"


Expand All @@ -44,7 +44,7 @@
#' \item{model}{name of the model that generated the forecasts}
#' \item{horizon}{forecast horizon in weeks}
#' }
#' @source \url{https://github.com/covid19-forecast-hub-europe/covid19-forecast-hub-europe/commit/a42867b1ea152c57e25b04f9faa26cfd4bfd8fa6/} # nolint
#' @source \url{https://github.com/covid19-forecast-hub-europe/covid19-forecast-hub-europe/commit/a42867b1ea152c57e25b04f9faa26cfd4bfd8fa6/}
"example_point"


Expand All @@ -69,7 +69,7 @@
#' \item{predicted}{predicted value}
#' \item{sample_id}{id for the corresponding sample}
#' }
#' @source \url{https://github.com/covid19-forecast-hub-europe/covid19-forecast-hub-europe/commit/a42867b1ea152c57e25b04f9faa26cfd4bfd8fa6/} # nolint
#' @source \url{https://github.com/covid19-forecast-hub-europe/covid19-forecast-hub-europe/commit/a42867b1ea152c57e25b04f9faa26cfd4bfd8fa6/}
"example_continuous"


Expand Down Expand Up @@ -124,7 +124,7 @@
#' \item{horizon}{forecast horizon in weeks}
#' \item{predicted}{predicted value}
#' }
#' @source \url{https://github.com/covid19-forecast-hub-europe/covid19-forecast-hub-europe/commit/a42867b1ea152c57e25b04f9faa26cfd4bfd8fa6/} # nolint
#' @source \url{https://github.com/covid19-forecast-hub-europe/covid19-forecast-hub-europe/commit/a42867b1ea152c57e25b04f9faa26cfd4bfd8fa6/}
"example_binary"


Expand All @@ -147,7 +147,7 @@
#' \item{model}{name of the model that generated the forecasts}
#' \item{horizon}{forecast horizon in weeks}
#' }
#' @source \url{https://github.com/covid19-forecast-hub-europe/covid19-forecast-hub-europe/commit/a42867b1ea152c57e25b04f9faa26cfd4bfd8fa6/} # nolint
#' @source \url{https://github.com/covid19-forecast-hub-europe/covid19-forecast-hub-europe/commit/a42867b1ea152c57e25b04f9faa26cfd4bfd8fa6/}
"example_quantile_forecasts_only"


Expand All @@ -167,7 +167,7 @@
#' \item{observed}{observed values}
#' \item{location_name}{name of the country for which a prediction was made}
#' }
#' @source \url{https://github.com/covid19-forecast-hub-europe/covid19-forecast-hub-europe/commit/a42867b1ea152c57e25b04f9faa26cfd4bfd8fa6/} # nolint
#' @source \url{https://github.com/covid19-forecast-hub-europe/covid19-forecast-hub-europe/commit/a42867b1ea152c57e25b04f9faa26cfd4bfd8fa6/}
"example_truth_only"

#' Summary information for selected metrics
Expand Down Expand Up @@ -211,3 +211,18 @@
#' - "se_mean" = [se_mean_sample()]
#' @keywords info
"metrics_sample"

#' Default metrics for quantile-based forecasts.
#'
#' A named list with functions:
#' - "wis" = [wis()]
#' - "overprediction" = [overprediction()]
#' - "underprediction" = [underprediction()]
#' - "dispersion" = [dispersion()]
#' - "bias" = [bias_quantile()]
#' - "coverage_50" = \(...) {run_safely(..., range = 50, fun = [interval_coverage_quantile][interval_coverage_quantile()])}
#' - "coverage_90" = \(...) {run_safely(..., range = 90, fun = [interval_coverage_quantile][interval_coverage_quantile()])}
#' - "coverage_deviation" = [interval_coverage_deviation_quantile()],
#' - "ae_median" = [ae_median_quantile()]
#' @keywords info
"metrics_quantile"
2 changes: 2 additions & 0 deletions R/get_-functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,8 @@ get_protected_columns <- function(data = NULL) {
protected_columns <- c(
"predicted", "observed", "sample_id", "quantile", "upper", "lower",
"pit_value", "range", "boundary", "relative_skill", "scaled_rel_skill",
"interval_coverage", "interval_coverage_deviation",
"quantile_coverage", "quantile_coverage_deviation",
available_metrics(),
grep("coverage_", names(data), fixed = TRUE, value = TRUE)
)
Expand Down
Loading

0 comments on commit eb45cbb

Please sign in to comment.