diff --git a/NEWS.md b/NEWS.md
index e77c172c7..05eae9c49 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -35,7 +35,9 @@ The update introduces breaking changes. If you want to keep using the older vers
- `plot_avail_forecasts()` was renamed `plot_forecast_counts()` in line with the change in the function name. The `x` argument no longer has a default value, as the value will depend on the data provided by the user.
- The deprecated `..density..` was replaced with `after_stat(density)` in ggplot calls.
- Files ending in ".Rda" were renamed to ".rds" where appropriate when used together with `saveRDS()` or `readRDS()`.
-- added documentation for the return value of `summarise_scores()`.
+- `score()` now calls `na.omit()` on the data, instead of only removing rows with missing values in the columns `observed` and `predicted`. This is because `NA` values in other columns can also cause problems, e.g. when grouping forecasts by the unit of a single forecast.
+- Added documentation for the return value of `summarise_scores()`.
+
# scoringutils 1.2.1
diff --git a/R/available_forecasts.R b/R/available_forecasts.R
index 40eda6080..db4d40c88 100644
--- a/R/available_forecasts.R
+++ b/R/available_forecasts.R
@@ -40,7 +40,7 @@ get_forecast_counts <- function(data,
data <- as_forecast(data)
forecast_unit <- attr(data, "forecast_unit")
- data <- remove_na_observed_predicted(data)
+ data <- na.omit(data)
if (is.null(by)) {
by <- forecast_unit
diff --git a/R/pit.R b/R/pit.R
index 067db5123..fd8f4819d 100644
--- a/R/pit.R
+++ b/R/pit.R
@@ -186,7 +186,7 @@ pit <- function(data,
n_replicates = 100) {
data <- as_forecast(data)
- data <- remove_na_observed_predicted(data)
+ data <- na.omit(data)
forecast_type <- get_forecast_type(data)
if (forecast_type == "quantile") {
diff --git a/R/score.R b/R/score.R
index 5ca049fa7..94d81fb16 100644
--- a/R/score.R
+++ b/R/score.R
@@ -29,6 +29,7 @@
#' unsummarised scores to obtain one score per forecast unit for quantile-based
#' forecasts.
#' @importFrom data.table ':=' as.data.table
+#' @importFrom stats na.omit
#' @examples
#' library(magrittr) # pipe operator
#' data.table::setDTthreads(1) # only needed to avoid issues on CRAN
@@ -74,11 +75,12 @@ score.default <- function(data, ...) {
score(data, ...)
}
+#' @importFrom stats na.omit
#' @rdname score
#' @export
score.forecast_binary <- function(data, metrics = metrics_binary, ...) {
data <- validate_forecast(data)
- data <- remove_na_observed_predicted(data)
+ data <- na.omit(data)
metrics <- validate_metrics(metrics)
data <- apply_rules(
@@ -94,11 +96,12 @@ score.forecast_binary <- function(data, metrics = metrics_binary, ...) {
#' @importFrom Metrics se ae ape
+#' @importFrom stats na.omit
#' @rdname score
#' @export
score.forecast_point <- function(data, metrics = metrics_point, ...) {
data <- validate_forecast(data)
- data <- remove_na_observed_predicted(data)
+ data <- na.omit(data)
metrics <- validate_metrics(metrics)
data <- apply_rules(
@@ -111,11 +114,12 @@ score.forecast_point <- function(data, metrics = metrics_point, ...) {
return(data[])
}
+#' @importFrom stats na.omit
#' @rdname score
#' @export
score.forecast_sample <- function(data, metrics = metrics_sample, ...) {
data <- validate_forecast(data)
- data <- remove_na_observed_predicted(data)
+ data <- na.omit(data)
forecast_unit <- attr(data, "forecast_unit")
metrics <- validate_metrics(metrics)
@@ -147,12 +151,13 @@ score.forecast_sample <- function(data, metrics = metrics_sample, ...) {
return(data[])
}
+#' @importFrom stats na.omit
#' @importFrom data.table `:=` as.data.table rbindlist %like%
#' @rdname score
#' @export
score.forecast_quantile <- function(data, metrics = metrics_quantile, ...) {
data <- validate_forecast(data)
- data <- remove_na_observed_predicted(data)
+ data <- na.omit(data)
forecast_unit <- attr(data, "forecast_unit")
metrics <- validate_metrics(metrics)
@@ -163,8 +168,7 @@ score.forecast_quantile <- function(data, metrics = metrics_quantile, ...) {
observed = unique(observed),
quantile = list(sort(quantile, na.last = TRUE)),
scoringutils_quantile = toString(sort(quantile, na.last = TRUE))
- ),
- by = forecast_unit]
+ ), by = forecast_unit]
# split according to quantile lengths and do calculations for different
# quantile lengths separately. The function `wis()` assumes that all
diff --git a/R/summarise_scores.R b/R/summarise_scores.R
index 5bda00557..ac01e017f 100644
--- a/R/summarise_scores.R
+++ b/R/summarise_scores.R
@@ -35,7 +35,6 @@
#' summarise_scores(scores)
#' }
#'
-#'
#' # summarise over samples or quantiles to get one score per forecast
#' scores <- score(example_quantile)
#' summarise_scores(scores)
diff --git a/R/utils.R b/R/utils.R
index 8e5e8a9e1..b84f4f71b 100644
--- a/R/utils.R
+++ b/R/utils.R
@@ -3,23 +3,12 @@
#' @return A vector with the name of all available metrics
#' @export
#' @keywords info
-
available_metrics <- function() {
return(unique(c(scoringutils::metrics$Name,
"wis", "coverage_50", "coverage_90")))
}
-remove_na_observed_predicted <- function(data) {
- # remove rows where predicted or observed value are NA -----------------------
- data <- data[!is.na(observed) & !is.na(predicted)]
- if (nrow(data) == 0) {
- stop("After removing NA values in `observed` and `predicted`, there were no observations left")
- }
- return(data[])
-}
-
-
#' @title Collapse several messages to one
#'
#' @description Internal helper function to facilitate generating messages
@@ -38,7 +27,6 @@ collapse_messages <- function(type = "messages", messages) {
}
-
#' @title Print output from `check_forecasts()`
#'
#' @description Helper function that prints the output generated by
diff --git a/R/validate.R b/R/validate.R
index 4a1be5f1f..faab1fbc6 100644
--- a/R/validate.R
+++ b/R/validate.R
@@ -155,14 +155,17 @@ validate_general <- function(data) {
setattr(data, "warnings", number_quantiles_samples)
}
- # check whether there are any NA values in the predicted or observed values
- messages <- c(
- check_no_NA_present(data, "predicted"),
- check_no_NA_present(data, "observed")
- )
- if (!is.logical(messages)) {
- messages <- messages[messages != "TRUE"]
- setattr(data, "messages", messages)
+ # check whether there are any NA values
+ if (anyNA(data)) {
+ if (nrow(na.omit(data)) == 0) {
+ stop(
+ "After removing rows with NA values in the data, no forecasts are left."
+ )
+ }
+ message(
+ "Some rows containing NA values may be removed. ",
+ "This is fine if not unexpected."
+ )
}
return(data[])
diff --git a/man/summarise_scores.Rd b/man/summarise_scores.Rd
index 3c461bfd0..9d6ce674a 100644
--- a/man/summarise_scores.Rd
+++ b/man/summarise_scores.Rd
@@ -52,7 +52,6 @@ scores <- score(example_continuous)
summarise_scores(scores)
}
-
# summarise over samples or quantiles to get one score per forecast
scores <- score(example_quantile)
summarise_scores(scores)
diff --git a/tests/testthat/_snaps/plot_interval_coverage/plot-interval-coverage.svg b/tests/testthat/_snaps/plot_interval_coverage/plot-interval-coverage.svg
index 91848b1dd..548878c34 100644
--- a/tests/testthat/_snaps/plot_interval_coverage/plot-interval-coverage.svg
+++ b/tests/testthat/_snaps/plot_interval_coverage/plot-interval-coverage.svg
@@ -57,17 +57,15 @@
100
Nominal interval coverage
% Obs inside interval
-model
-
-
-
-
-
-EuroCOVIDhub-baseline
-EuroCOVIDhub-ensemble
-UMass-MechBayes
-epiforecasts-EpiNow2
-NA
+model
+
+
+
+
+EuroCOVIDhub-baseline
+EuroCOVIDhub-ensemble
+UMass-MechBayes
+epiforecasts-EpiNow2
plot_interval_coverage
diff --git a/tests/testthat/_snaps/plot_quantile_coverage/plot-quantile-coverage.svg b/tests/testthat/_snaps/plot_quantile_coverage/plot-quantile-coverage.svg
index 76808cc67..bf686eedb 100644
--- a/tests/testthat/_snaps/plot_quantile_coverage/plot-quantile-coverage.svg
+++ b/tests/testthat/_snaps/plot_quantile_coverage/plot-quantile-coverage.svg
@@ -57,17 +57,15 @@
1.00
Quantile
% Obs below quantile
-model
-
-
-
-
-
-EuroCOVIDhub-baseline
-EuroCOVIDhub-ensemble
-UMass-MechBayes
-epiforecasts-EpiNow2
-NA
+model
+
+
+
+
+EuroCOVIDhub-baseline
+EuroCOVIDhub-ensemble
+UMass-MechBayes
+epiforecasts-EpiNow2
plot_quantile_coverage
diff --git a/tests/testthat/test-add_coverage.R b/tests/testthat/test-add_coverage.R
index 689b8640b..32b722a8a 100644
--- a/tests/testthat/test-add_coverage.R
+++ b/tests/testthat/test-add_coverage.R
@@ -1,7 +1,10 @@
ex_coverage <- example_quantile[model == "EuroCOVIDhub-ensemble"]
test_that("add_coverage() works as expected", {
- expect_no_condition(cov <- add_coverage(example_quantile))
+ expect_message(
+ cov <- add_coverage(example_quantile),
+ "Some rows containing NA values may be removed."
+ )
required_names <- c(
"range", "interval_coverage", "interval_coverage_deviation",
diff --git a/tests/testthat/test-available_forecasts.R b/tests/testthat/test-available_forecasts.R
index ffa9e4f9e..4507820b1 100644
--- a/tests/testthat/test-available_forecasts.R
+++ b/tests/testthat/test-available_forecasts.R
@@ -8,26 +8,26 @@ test_that("get_forecast_counts() works as expected", {
expect_type(af$target_type, "character")
expect_type(af$`count`, "integer")
expect_equal(nrow(af[is.na(`count`)]), 0)
- af <- get_forecast_counts(example_quantile, by = "model")
+ af <- get_forecast_counts(na.omit(example_quantile), by = "model")
expect_equal(nrow(af), 4)
expect_equal(af$`count`, c(256, 256, 128, 247))
# Setting `collapse = c()` means that all quantiles and samples are counted
af <- get_forecast_counts(
- example_quantile,
+ na.omit(example_quantile),
by = "model", collapse = c()
)
expect_equal(nrow(af), 4)
expect_equal(af$`count`, c(5888, 5888, 2944, 5681))
# setting by = NULL, the default, results in by equal to forecast unit
- af <- get_forecast_counts(example_quantile)
+ af <- get_forecast_counts(na.omit(example_quantile))
expect_equal(nrow(af), 50688)
# check whether collapsing also works for model-based forecasts
- af <- get_forecast_counts(example_integer, by = "model")
+ af <- get_forecast_counts(na.omit(example_integer), by = "model")
expect_equal(nrow(af), 4)
- af <- get_forecast_counts(example_integer, by = "model", collapse = c())
+ af <- get_forecast_counts(na.omit(example_integer), by = "model", collapse = c())
expect_equal(af$count, c(10240, 10240, 5120, 9880))
})
diff --git a/tests/testthat/test-check_forecasts.R b/tests/testthat/test-check_forecasts.R
index 004031d71..447596768 100644
--- a/tests/testthat/test-check_forecasts.R
+++ b/tests/testthat/test-check_forecasts.R
@@ -96,7 +96,7 @@ test_that("as_forecast() function throws an error when no predictions or observe
test_that("output of check_forecasts() is accepted as input to score()", {
check <- suppressMessages(as_forecast(example_binary))
expect_no_error(
- score_check <- score(check)
+ score_check <- score(na.omit(check))
)
expect_equal(score_check, suppressMessages(score(example_binary)))
})
diff --git a/tests/testthat/test-convenience-functions.R b/tests/testthat/test-convenience-functions.R
index 98d784cd4..d5f83aeb9 100644
--- a/tests/testthat/test-convenience-functions.R
+++ b/tests/testthat/test-convenience-functions.R
@@ -51,7 +51,7 @@ test_that("function set_forecast_unit() works", {
example_quantile,
c("location", "target_end_date", "target_type", "horizon", "model")
)
- scores2 <- score(ex2)
+ scores2 <- score(na.omit(ex2))
scores2 <- scores2[order(location, target_end_date, target_type, horizon, model), ]
expect_equal(scores1$interval_score, scores2$interval_score)
diff --git a/tests/testthat/test-get_-functions.R b/tests/testthat/test-get_-functions.R
index 5368270a1..f695bb936 100644
--- a/tests/testthat/test-get_-functions.R
+++ b/tests/testthat/test-get_-functions.R
@@ -1,7 +1,6 @@
# ==============================================================================
# `get_forecast_unit()`
# ==============================================================================
-
test_that("get_forecast_unit() works as expected", {
expect_equal(
get_forecast_unit(example_quantile),
@@ -15,7 +14,7 @@ test_that("get_forecast_unit() works as expected", {
"forecast_date", "model", "horizon")
)
- data <- as_forecast(example_quantile)
+ data <- as_forecast(na.omit(example_quantile))
ex <- data[, location := NULL]
expect_warning(
get_forecast_unit(ex, check_conflict = TRUE),
@@ -28,6 +27,36 @@ fixed = TRUE
})
+# ==============================================================================
+# Test removing `NA` values from the data
+# ==============================================================================
+test_that("removing NA rows from data works as expected", {
+ expect_equal(nrow(na.omit(example_quantile)), 20401)
+
+ ex <- data.frame(observed = c(NA, 1:3), predicted = 1:4)
+ expect_equal(nrow(na.omit(ex)), 3)
+
+ ex$predicted <- c(1:3, NA)
+ expect_equal(nrow(na.omit(ex)), 2)
+
+ # test that attributes and classes are retained
+ ex <- as_forecast(na.omit(example_integer))
+ expect_equal(
+ class(na.omit(ex)),
+ c("forecast_sample", "data.table", "data.frame")
+ )
+
+ attributes <- get_scoringutils_attributes(ex)
+ expect_equal(
+ get_scoringutils_attributes(na.omit(ex)),
+ attributes
+ )
+})
+
+
+# ==============================================================================
+# `get_type()`
+# ==============================================================================
test_that("get_type() works as expected with vectors", {
expect_equal(get_type(1:3), "integer")
expect_equal(get_type(factor(1:2)), "classification")
@@ -164,7 +193,7 @@ test_that("get_forecast_type() works as expected", {
fixed = TRUE
)
- data <- as_forecast(example_integer)
+ data <- as_forecast(na.omit(example_integer))
attr(data, "forecast_type") <- "binary"
expect_warning(
get_forecast_type(data),
diff --git a/tests/testthat/test-plot_avail_forecasts.R b/tests/testthat/test-plot_avail_forecasts.R
index f131a6d7a..4240552cb 100644
--- a/tests/testthat/test-plot_avail_forecasts.R
+++ b/tests/testthat/test-plot_avail_forecasts.R
@@ -1,6 +1,6 @@
test_that("plot.forecast_counts() works as expected", {
available_forecasts <- get_forecast_counts(
- example_quantile,
+ na.omit(example_quantile),
by = c("model", "target_type", "target_end_date")
)
p <- plot_forecast_counts(available_forecasts,
diff --git a/tests/testthat/test-plot_interval_coverage.R b/tests/testthat/test-plot_interval_coverage.R
index 5ff1fbf98..f8f3461ad 100644
--- a/tests/testthat/test-plot_interval_coverage.R
+++ b/tests/testthat/test-plot_interval_coverage.R
@@ -1,5 +1,5 @@
test_that("plot_interval_coverage() works as expected", {
- coverage <- add_coverage(example_quantile) %>%
+ coverage <- add_coverage(na.omit(example_quantile)) %>%
summarise_scores(by = c("model", "range"))
p <- plot_interval_coverage(coverage)
expect_s3_class(p, "ggplot")
diff --git a/tests/testthat/test-plot_quantile_coverage.R b/tests/testthat/test-plot_quantile_coverage.R
index 1851ccb5c..9b44b8bf1 100644
--- a/tests/testthat/test-plot_quantile_coverage.R
+++ b/tests/testthat/test-plot_quantile_coverage.R
@@ -1,5 +1,5 @@
test_that("plot_quantile_coverage() works as expected", {
- coverage <- add_coverage(example_quantile) %>%
+ coverage <- add_coverage(na.omit(example_quantile)) %>%
summarise_scores(by = c("model", "quantile"))
p <- plot_quantile_coverage(coverage)
diff --git a/tests/testthat/test-plot_ranges.R b/tests/testthat/test-plot_ranges.R
index 9a18cffe3..36c2ea41c 100644
--- a/tests/testthat/test-plot_ranges.R
+++ b/tests/testthat/test-plot_ranges.R
@@ -1,6 +1,7 @@
m <- modifyList(metrics_no_cov_no_ae, list("bias" = NULL))
sum_scores <- copy(example_quantile) %>%
+ na.omit() %>%
.[, interval_range := scoringutils:::get_range_from_quantile(quantile)] %>%
score(metrics = m) %>%
summarise_scores(by = c("model", "target_type", "interval_range"))
diff --git a/tests/testthat/test-score.R b/tests/testthat/test-score.R
index ac6249528..29e9cf9c1 100644
--- a/tests/testthat/test-score.R
+++ b/tests/testthat/test-score.R
@@ -42,7 +42,7 @@ test_that("score.forecast_binary() errors with only NA values", {
only_nas <- copy(example_binary)[, predicted := NA_real_]
expect_error(
score(only_nas),
- "After removing NA values in `observed` and `predicted`, there were no observations left"
+ "After removing rows with NA values in the data, no forecasts are left."
)
})
@@ -160,7 +160,7 @@ test_that("score.forecast_point() errors with only NA values", {
only_nas <- copy(example_point)[, predicted := NA_real_]
expect_error(
score(only_nas),
- "After removing NA values in `observed` and `predicted`, there were no observations left"
+ "After removing rows with NA values in the data, no forecasts are left."
)
})
@@ -243,7 +243,7 @@ test_that("score.forecast_quantile() errors with only NA values", {
only_nas <- copy(example_quantile)[, predicted := NA_real_]
expect_error(
score(only_nas),
- "After removing NA values in `observed` and `predicted`, there were no observations left"
+ "After removing rows with NA values in the data, no forecasts are left."
)
})
@@ -259,7 +259,7 @@ test_that("function produces output for a continuous format case", {
only_nas <- copy(example_continuous)[, predicted := NA_real_]
expect_error(
score(only_nas),
- "After removing NA values in `observed` and `predicted`, there were no observations left"
+ "After removing rows with NA values in the data, no forecasts are left."
)
expect_equal(