diff --git a/NEWS.md b/NEWS.md index e77c172c7..05eae9c49 100644 --- a/NEWS.md +++ b/NEWS.md @@ -35,7 +35,9 @@ The update introduces breaking changes. If you want to keep using the older vers - `plot_avail_forecasts()` was renamed `plot_forecast_counts()` in line with the change in the function name. The `x` argument no longer has a default value, as the value will depend on the data provided by the user. - The deprecated `..density..` was replaced with `after_stat(density)` in ggplot calls. - Files ending in ".Rda" were renamed to ".rds" where appropriate when used together with `saveRDS()` or `readRDS()`. -- added documentation for the return value of `summarise_scores()`. +- `score()` now calls `na.omit()` on the data, instead of only removing rows with missing values in the columns `observed` and `predicted`. This is because `NA` values in other columns can also mess up e.g. grouping of forecasts according to the unit of a single forecast. +- added documentation for the return value of `summarise_scores()`. + # scoringutils 1.2.1 diff --git a/R/available_forecasts.R b/R/available_forecasts.R index 40eda6080..db4d40c88 100644 --- a/R/available_forecasts.R +++ b/R/available_forecasts.R @@ -40,7 +40,7 @@ get_forecast_counts <- function(data, data <- as_forecast(data) forecast_unit <- attr(data, "forecast_unit") - data <- remove_na_observed_predicted(data) + data <- na.omit(data) if (is.null(by)) { by <- forecast_unit diff --git a/R/pit.R b/R/pit.R index 067db5123..fd8f4819d 100644 --- a/R/pit.R +++ b/R/pit.R @@ -186,7 +186,7 @@ pit <- function(data, n_replicates = 100) { data <- as_forecast(data) - data <- remove_na_observed_predicted(data) + data <- na.omit(data) forecast_type <- get_forecast_type(data) if (forecast_type == "quantile") { diff --git a/R/score.R b/R/score.R index 5ca049fa7..94d81fb16 100644 --- a/R/score.R +++ b/R/score.R @@ -29,6 +29,7 @@ #' unsummarised scores to obtain one score per forecast unit for quantile-based #' forecasts. #' @importFrom data.table ':=' as.data.table +#' @importFrom stats na.omit #' @examples #' library(magrittr) # pipe operator #' data.table::setDTthreads(1) # only needed to avoid issues on CRAN @@ -74,11 +75,12 @@ score.default <- function(data, ...) { score(data, ...) } +#' @importFrom stats na.omit #' @rdname score #' @export score.forecast_binary <- function(data, metrics = metrics_binary, ...) { data <- validate_forecast(data) - data <- remove_na_observed_predicted(data) + data <- na.omit(data) metrics <- validate_metrics(metrics) data <- apply_rules( @@ -94,11 +96,12 @@ score.forecast_binary <- function(data, metrics = metrics_binary, ...) { #' @importFrom Metrics se ae ape +#' @importFrom stats na.omit #' @rdname score #' @export score.forecast_point <- function(data, metrics = metrics_point, ...) { data <- validate_forecast(data) - data <- remove_na_observed_predicted(data) + data <- na.omit(data) metrics <- validate_metrics(metrics) data <- apply_rules( @@ -111,11 +114,12 @@ score.forecast_point <- function(data, metrics = metrics_point, ...) { return(data[]) } +#' @importFrom stats na.omit #' @rdname score #' @export score.forecast_sample <- function(data, metrics = metrics_sample, ...) { data <- validate_forecast(data) - data <- remove_na_observed_predicted(data) + data <- na.omit(data) forecast_unit <- attr(data, "forecast_unit") metrics <- validate_metrics(metrics) @@ -147,12 +151,13 @@ score.forecast_sample <- function(data, metrics = metrics_sample, ...) { return(data[]) } +#' @importFrom stats na.omit #' @importFrom data.table `:=` as.data.table rbindlist %like% #' @rdname score #' @export score.forecast_quantile <- function(data, metrics = metrics_quantile, ...) { data <- validate_forecast(data) - data <- remove_na_observed_predicted(data) + data <- na.omit(data) forecast_unit <- attr(data, "forecast_unit") metrics <- validate_metrics(metrics) @@ -163,8 +168,7 @@ score.forecast_quantile <- function(data, metrics = metrics_quantile, ...) { observed = unique(observed), quantile = list(sort(quantile, na.last = TRUE)), scoringutils_quantile = toString(sort(quantile, na.last = TRUE)) - ), - by = forecast_unit] + ), by = forecast_unit] # split according to quantile lengths and do calculations for different # quantile lengths separately. The function `wis()` assumes that all diff --git a/R/summarise_scores.R b/R/summarise_scores.R index 5bda00557..ac01e017f 100644 --- a/R/summarise_scores.R +++ b/R/summarise_scores.R @@ -35,7 +35,6 @@ #' summarise_scores(scores) #' } #' -#' #' # summarise over samples or quantiles to get one score per forecast #' scores <- score(example_quantile) #' summarise_scores(scores) diff --git a/R/utils.R b/R/utils.R index 8e5e8a9e1..b84f4f71b 100644 --- a/R/utils.R +++ b/R/utils.R @@ -3,23 +3,12 @@ #' @return A vector with the name of all available metrics #' @export #' @keywords info - available_metrics <- function() { return(unique(c(scoringutils::metrics$Name, "wis", "coverage_50", "coverage_90"))) } -remove_na_observed_predicted <- function(data) { - # remove rows where predicted or observed value are NA ----------------------- - data <- data[!is.na(observed) & !is.na(predicted)] - if (nrow(data) == 0) { - stop("After removing NA values in `observed` and `predicted`, there were no observations left") - } - return(data[]) -} - - #' @title Collapse several messages to one #' #' @description Internal helper function to facilitate generating messages @@ -38,7 +27,6 @@ collapse_messages <- function(type = "messages", messages) { } - #' @title Print output from `check_forecasts()` #' #' @description Helper function that prints the output generated by diff --git a/R/validate.R b/R/validate.R index 4a1be5f1f..faab1fbc6 100644 --- a/R/validate.R +++ b/R/validate.R @@ -155,14 +155,17 @@ validate_general <- function(data) { setattr(data, "warnings", number_quantiles_samples) } - # check whether there are any NA values in the predicted or observed values - messages <- c( - check_no_NA_present(data, "predicted"), - check_no_NA_present(data, "observed") - ) - if (!is.logical(messages)) { - messages <- messages[messages != "TRUE"] - setattr(data, "messages", messages) + # check whether there are any NA values + if (anyNA(data)) { + if (nrow(na.omit(data)) == 0) { + stop( + "After removing rows with NA values in the data, no forecasts are left." + ) + } + message( + "Some rows containing NA values may be removed. ", + "This is fine if not unexpected." + ) } return(data[]) diff --git a/man/summarise_scores.Rd b/man/summarise_scores.Rd index 3c461bfd0..9d6ce674a 100644 --- a/man/summarise_scores.Rd +++ b/man/summarise_scores.Rd @@ -52,7 +52,6 @@ scores <- score(example_continuous) summarise_scores(scores) } - # summarise over samples or quantiles to get one score per forecast scores <- score(example_quantile) summarise_scores(scores) diff --git a/tests/testthat/_snaps/plot_interval_coverage/plot-interval-coverage.svg b/tests/testthat/_snaps/plot_interval_coverage/plot-interval-coverage.svg index 91848b1dd..548878c34 100644 --- a/tests/testthat/_snaps/plot_interval_coverage/plot-interval-coverage.svg +++ b/tests/testthat/_snaps/plot_interval_coverage/plot-interval-coverage.svg @@ -57,17 +57,15 @@ 100 Nominal interval coverage % Obs inside interval -model - - - - - -EuroCOVIDhub-baseline -EuroCOVIDhub-ensemble -UMass-MechBayes -epiforecasts-EpiNow2 -NA +model + + + + +EuroCOVIDhub-baseline +EuroCOVIDhub-ensemble +UMass-MechBayes +epiforecasts-EpiNow2 plot_interval_coverage diff --git a/tests/testthat/_snaps/plot_quantile_coverage/plot-quantile-coverage.svg b/tests/testthat/_snaps/plot_quantile_coverage/plot-quantile-coverage.svg index 76808cc67..bf686eedb 100644 --- a/tests/testthat/_snaps/plot_quantile_coverage/plot-quantile-coverage.svg +++ b/tests/testthat/_snaps/plot_quantile_coverage/plot-quantile-coverage.svg @@ -57,17 +57,15 @@ 1.00 Quantile % Obs below quantile -model - - - - - -EuroCOVIDhub-baseline -EuroCOVIDhub-ensemble -UMass-MechBayes -epiforecasts-EpiNow2 -NA +model + + + + +EuroCOVIDhub-baseline +EuroCOVIDhub-ensemble +UMass-MechBayes +epiforecasts-EpiNow2 plot_quantile_coverage diff --git a/tests/testthat/test-add_coverage.R b/tests/testthat/test-add_coverage.R index 689b8640b..32b722a8a 100644 --- a/tests/testthat/test-add_coverage.R +++ b/tests/testthat/test-add_coverage.R @@ -1,7 +1,10 @@ ex_coverage <- example_quantile[model == "EuroCOVIDhub-ensemble"] test_that("add_coverage() works as expected", { - expect_no_condition(cov <- add_coverage(example_quantile)) + expect_message( + cov <- add_coverage(example_quantile), + "Some rows containing NA values may be removed." + ) required_names <- c( "range", "interval_coverage", "interval_coverage_deviation", diff --git a/tests/testthat/test-available_forecasts.R b/tests/testthat/test-available_forecasts.R index ffa9e4f9e..4507820b1 100644 --- a/tests/testthat/test-available_forecasts.R +++ b/tests/testthat/test-available_forecasts.R @@ -8,26 +8,26 @@ test_that("get_forecast_counts() works as expected", { expect_type(af$target_type, "character") expect_type(af$`count`, "integer") expect_equal(nrow(af[is.na(`count`)]), 0) - af <- get_forecast_counts(example_quantile, by = "model") + af <- get_forecast_counts(na.omit(example_quantile), by = "model") expect_equal(nrow(af), 4) expect_equal(af$`count`, c(256, 256, 128, 247)) # Setting `collapse = c()` means that all quantiles and samples are counted af <- get_forecast_counts( - example_quantile, + na.omit(example_quantile), by = "model", collapse = c() ) expect_equal(nrow(af), 4) expect_equal(af$`count`, c(5888, 5888, 2944, 5681)) # setting by = NULL, the default, results in by equal to forecast unit - af <- get_forecast_counts(example_quantile) + af <- get_forecast_counts(na.omit(example_quantile)) expect_equal(nrow(af), 50688) # check whether collapsing also works for model-based forecasts - af <- get_forecast_counts(example_integer, by = "model") + af <- get_forecast_counts(na.omit(example_integer), by = "model") expect_equal(nrow(af), 4) - af <- get_forecast_counts(example_integer, by = "model", collapse = c()) + af <- get_forecast_counts(na.omit(example_integer), by = "model", collapse = c()) expect_equal(af$count, c(10240, 10240, 5120, 9880)) }) diff --git a/tests/testthat/test-check_forecasts.R b/tests/testthat/test-check_forecasts.R index 004031d71..447596768 100644 --- a/tests/testthat/test-check_forecasts.R +++ b/tests/testthat/test-check_forecasts.R @@ -96,7 +96,7 @@ test_that("as_forecast() function throws an error when no predictions or observe test_that("output of check_forecasts() is accepted as input to score()", { check <- suppressMessages(as_forecast(example_binary)) expect_no_error( - score_check <- score(check) + score_check <- score(na.omit(check)) ) expect_equal(score_check, suppressMessages(score(example_binary))) }) diff --git a/tests/testthat/test-convenience-functions.R b/tests/testthat/test-convenience-functions.R index 98d784cd4..d5f83aeb9 100644 --- a/tests/testthat/test-convenience-functions.R +++ b/tests/testthat/test-convenience-functions.R @@ -51,7 +51,7 @@ test_that("function set_forecast_unit() works", { example_quantile, c("location", "target_end_date", "target_type", "horizon", "model") ) - scores2 <- score(ex2) + scores2 <- score(na.omit(ex2)) scores2 <- scores2[order(location, target_end_date, target_type, horizon, model), ] expect_equal(scores1$interval_score, scores2$interval_score) diff --git a/tests/testthat/test-get_-functions.R b/tests/testthat/test-get_-functions.R index 5368270a1..f695bb936 100644 --- a/tests/testthat/test-get_-functions.R +++ b/tests/testthat/test-get_-functions.R @@ -1,7 +1,6 @@ # ============================================================================== # `get_forecast_unit()` # ============================================================================== - test_that("get_forecast_unit() works as expected", { expect_equal( get_forecast_unit(example_quantile), @@ -15,7 +14,7 @@ test_that("get_forecast_unit() works as expected", { "forecast_date", "model", "horizon") ) - data <- as_forecast(example_quantile) + data <- as_forecast(na.omit(example_quantile)) ex <- data[, location := NULL] expect_warning( get_forecast_unit(ex, check_conflict = TRUE), @@ -28,6 +27,36 @@ fixed = TRUE }) +# ============================================================================== +# Test removing `NA` values from the data +# ============================================================================== +test_that("removing NA rows from data works as expected", { + expect_equal(nrow(na.omit(example_quantile)), 20401) + + ex <- data.frame(observed = c(NA, 1:3), predicted = 1:4) + expect_equal(nrow(na.omit(ex)), 3) + + ex$predicted <- c(1:3, NA) + expect_equal(nrow(na.omit(ex)), 2) + + # test that attributes and classes are retained + ex <- as_forecast(na.omit(example_integer)) + expect_equal( + class(na.omit(ex)), + c("forecast_sample", "data.table", "data.frame") + ) + + attributes <- get_scoringutils_attributes(ex) + expect_equal( + get_scoringutils_attributes(na.omit(ex)), + attributes + ) +}) + + +# ============================================================================== +# `get_type()` +# ============================================================================== test_that("get_type() works as expected with vectors", { expect_equal(get_type(1:3), "integer") expect_equal(get_type(factor(1:2)), "classification") @@ -164,7 +193,7 @@ test_that("get_forecast_type() works as expected", { fixed = TRUE ) - data <- as_forecast(example_integer) + data <- as_forecast(na.omit(example_integer)) attr(data, "forecast_type") <- "binary" expect_warning( get_forecast_type(data), diff --git a/tests/testthat/test-plot_avail_forecasts.R b/tests/testthat/test-plot_avail_forecasts.R index f131a6d7a..4240552cb 100644 --- a/tests/testthat/test-plot_avail_forecasts.R +++ b/tests/testthat/test-plot_avail_forecasts.R @@ -1,6 +1,6 @@ test_that("plot.forecast_counts() works as expected", { available_forecasts <- get_forecast_counts( - example_quantile, + na.omit(example_quantile), by = c("model", "target_type", "target_end_date") ) p <- plot_forecast_counts(available_forecasts, diff --git a/tests/testthat/test-plot_interval_coverage.R b/tests/testthat/test-plot_interval_coverage.R index 5ff1fbf98..f8f3461ad 100644 --- a/tests/testthat/test-plot_interval_coverage.R +++ b/tests/testthat/test-plot_interval_coverage.R @@ -1,5 +1,5 @@ test_that("plot_interval_coverage() works as expected", { - coverage <- add_coverage(example_quantile) %>% + coverage <- add_coverage(na.omit(example_quantile)) %>% summarise_scores(by = c("model", "range")) p <- plot_interval_coverage(coverage) expect_s3_class(p, "ggplot") diff --git a/tests/testthat/test-plot_quantile_coverage.R b/tests/testthat/test-plot_quantile_coverage.R index 1851ccb5c..9b44b8bf1 100644 --- a/tests/testthat/test-plot_quantile_coverage.R +++ b/tests/testthat/test-plot_quantile_coverage.R @@ -1,5 +1,5 @@ test_that("plot_quantile_coverage() works as expected", { - coverage <- add_coverage(example_quantile) %>% + coverage <- add_coverage(na.omit(example_quantile)) %>% summarise_scores(by = c("model", "quantile")) p <- plot_quantile_coverage(coverage) diff --git a/tests/testthat/test-plot_ranges.R b/tests/testthat/test-plot_ranges.R index 9a18cffe3..36c2ea41c 100644 --- a/tests/testthat/test-plot_ranges.R +++ b/tests/testthat/test-plot_ranges.R @@ -1,6 +1,7 @@ m <- modifyList(metrics_no_cov_no_ae, list("bias" = NULL)) sum_scores <- copy(example_quantile) %>% + na.omit() %>% .[, interval_range := scoringutils:::get_range_from_quantile(quantile)] %>% score(metrics = m) %>% summarise_scores(by = c("model", "target_type", "interval_range")) diff --git a/tests/testthat/test-score.R b/tests/testthat/test-score.R index ac6249528..29e9cf9c1 100644 --- a/tests/testthat/test-score.R +++ b/tests/testthat/test-score.R @@ -42,7 +42,7 @@ test_that("score.forecast_binary() errors with only NA values", { only_nas <- copy(example_binary)[, predicted := NA_real_] expect_error( score(only_nas), - "After removing NA values in `observed` and `predicted`, there were no observations left" + "After removing rows with NA values in the data, no forecasts are left." ) }) @@ -160,7 +160,7 @@ test_that("score.forecast_point() errors with only NA values", { only_nas <- copy(example_point)[, predicted := NA_real_] expect_error( score(only_nas), - "After removing NA values in `observed` and `predicted`, there were no observations left" + "After removing rows with NA values in the data, no forecasts are left." ) }) @@ -243,7 +243,7 @@ test_that("score.forecast_quantile() errors with only NA values", { only_nas <- copy(example_quantile)[, predicted := NA_real_] expect_error( score(only_nas), - "After removing NA values in `observed` and `predicted`, there were no observations left" + "After removing rows with NA values in the data, no forecasts are left." ) }) @@ -259,7 +259,7 @@ test_that("function produces output for a continuous format case", { only_nas <- copy(example_continuous)[, predicted := NA_real_] expect_error( score(only_nas), - "After removing NA values in `observed` and `predicted`, there were no observations left" + "After removing rows with NA values in the data, no forecasts are left." ) expect_equal(