Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue #404: Use na.omit() to remove NA values before scoring #465

Merged
merged 17 commits into from
Jan 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ The update introduces breaking changes. If you want to keep using the older vers
- `plot_avail_forecasts()` was renamed `plot_forecast_counts()` in line with the change in the function name. The `x` argument no longer has a default value, as the value will depend on the data provided by the user.
- The deprecated `..density..` was replaced with `after_stat(density)` in ggplot calls.
- Files ending in ".Rda" were renamed to ".rds" where appropriate when used together with `saveRDS()` or `readRDS()`.
- added documentation for the return value of `summarise_scores()`.
- `score()` now calls `na.omit()` on the data, instead of only removing rows with missing values in the columns `observed` and `predicted`. This is because `NA` values in other columns can also cause problems, e.g. when grouping forecasts according to the unit of a single forecast.
- added documentation for the return value of `summarise_scores()`.


# scoringutils 1.2.1

Expand Down
2 changes: 1 addition & 1 deletion R/available_forecasts.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ get_forecast_counts <- function(data,

data <- as_forecast(data)
forecast_unit <- attr(data, "forecast_unit")
data <- remove_na_observed_predicted(data)
data <- na.omit(data)

if (is.null(by)) {
by <- forecast_unit
Expand Down
2 changes: 1 addition & 1 deletion R/pit.R
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ pit <- function(data,
n_replicates = 100) {

data <- as_forecast(data)
data <- remove_na_observed_predicted(data)
data <- na.omit(data)
forecast_type <- get_forecast_type(data)

if (forecast_type == "quantile") {
Expand Down
16 changes: 10 additions & 6 deletions R/score.R
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#' unsummarised scores to obtain one score per forecast unit for quantile-based
#' forecasts.
#' @importFrom data.table ':=' as.data.table
#' @importFrom stats na.omit
#' @examples
#' library(magrittr) # pipe operator
#' data.table::setDTthreads(1) # only needed to avoid issues on CRAN
Expand Down Expand Up @@ -74,11 +75,12 @@ score.default <- function(data, ...) {
score(data, ...)
}

#' @importFrom stats na.omit
#' @rdname score
#' @export
score.forecast_binary <- function(data, metrics = metrics_binary, ...) {
data <- validate_forecast(data)
data <- remove_na_observed_predicted(data)
data <- na.omit(data)
metrics <- validate_metrics(metrics)

data <- apply_rules(
Expand All @@ -94,11 +96,12 @@ score.forecast_binary <- function(data, metrics = metrics_binary, ...) {


#' @importFrom Metrics se ae ape
#' @importFrom stats na.omit
#' @rdname score
#' @export
score.forecast_point <- function(data, metrics = metrics_point, ...) {
data <- validate_forecast(data)
data <- remove_na_observed_predicted(data)
data <- na.omit(data)
metrics <- validate_metrics(metrics)

data <- apply_rules(
Expand All @@ -111,11 +114,12 @@ score.forecast_point <- function(data, metrics = metrics_point, ...) {
return(data[])
}

#' @importFrom stats na.omit
#' @rdname score
#' @export
score.forecast_sample <- function(data, metrics = metrics_sample, ...) {
data <- validate_forecast(data)
data <- remove_na_observed_predicted(data)
data <- na.omit(data)
forecast_unit <- attr(data, "forecast_unit")
metrics <- validate_metrics(metrics)

Expand Down Expand Up @@ -147,12 +151,13 @@ score.forecast_sample <- function(data, metrics = metrics_sample, ...) {
return(data[])
}

#' @importFrom stats na.omit
#' @importFrom data.table `:=` as.data.table rbindlist %like%
#' @rdname score
#' @export
score.forecast_quantile <- function(data, metrics = metrics_quantile, ...) {
data <- validate_forecast(data)
data <- remove_na_observed_predicted(data)
data <- na.omit(data)
forecast_unit <- attr(data, "forecast_unit")
metrics <- validate_metrics(metrics)

Expand All @@ -163,8 +168,7 @@ score.forecast_quantile <- function(data, metrics = metrics_quantile, ...) {
observed = unique(observed),
quantile = list(sort(quantile, na.last = TRUE)),
scoringutils_quantile = toString(sort(quantile, na.last = TRUE))
),
by = forecast_unit]
), by = forecast_unit]

# split according to quantile lengths and do calculations for different
# quantile lengths separately. The function `wis()` assumes that all
Expand Down
1 change: 0 additions & 1 deletion R/summarise_scores.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
#' summarise_scores(scores)
#' }
#'
#'
#' # summarise over samples or quantiles to get one score per forecast
#' scores <- score(example_quantile)
#' summarise_scores(scores)
Expand Down
12 changes: 0 additions & 12 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,12 @@
#' @return A vector with the name of all available metrics
#' @export
#' @keywords info

available_metrics <- function() {
  # Names taken from the `Name` element of `scoringutils::metrics`, followed
  # by three additional score names; duplicates are dropped.
  extra_metrics <- c("wis", "coverage_50", "coverage_90")
  all_metrics <- c(scoringutils::metrics$Name, extra_metrics)
  unique(all_metrics)
}


remove_na_observed_predicted <- function(data) {
  # Keep only the rows in which neither `observed` nor `predicted` is NA.
  complete_rows <- data[!is.na(observed) & !is.na(predicted)]

  # Fail if filtering removed every row.
  if (nrow(complete_rows) == 0) {
    stop("After removing NA values in `observed` and `predicted`, there were no observations left")
  }

  complete_rows[]
}


#' @title Collapse several messages to one
#'
#' @description Internal helper function to facilitate generating messages
Expand All @@ -38,7 +27,6 @@ collapse_messages <- function(type = "messages", messages) {
}



#' @title Print output from `check_forecasts()`
#'
#' @description Helper function that prints the output generated by
Expand Down
19 changes: 11 additions & 8 deletions R/validate.R
Original file line number Diff line number Diff line change
Expand Up @@ -155,14 +155,17 @@ validate_general <- function(data) {
setattr(data, "warnings", number_quantiles_samples)
}

# check whether there are any NA values in the predicted or observed values
messages <- c(
check_no_NA_present(data, "predicted"),
check_no_NA_present(data, "observed")
)
if (!is.logical(messages)) {
messages <- messages[messages != "TRUE"]
setattr(data, "messages", messages)
# check whether there are any NA values
if (anyNA(data)) {
if (nrow(na.omit(data)) == 0) {
stop(
"After removing rows with NA values in the data, no forecasts are left."
)
}
message(
"Some rows containing NA values may be removed. ",
seabbs marked this conversation as resolved.
Show resolved Hide resolved
"This is fine if not unexpected."
)
}

return(data[])
Expand Down
1 change: 0 additions & 1 deletion man/summarise_scores.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
5 changes: 4 additions & 1 deletion tests/testthat/test-add_coverage.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
ex_coverage <- example_quantile[model == "EuroCOVIDhub-ensemble"]

test_that("add_coverage() works as expected", {
expect_no_condition(cov <- add_coverage(example_quantile))
expect_message(
cov <- add_coverage(example_quantile),
"Some rows containing NA values may be removed."
)

required_names <- c(
"range", "interval_coverage", "interval_coverage_deviation",
Expand Down
10 changes: 5 additions & 5 deletions tests/testthat/test-available_forecasts.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,26 @@ test_that("get_forecast_counts() works as expected", {
expect_type(af$target_type, "character")
expect_type(af$`count`, "integer")
expect_equal(nrow(af[is.na(`count`)]), 0)
af <- get_forecast_counts(example_quantile, by = "model")
af <- get_forecast_counts(na.omit(example_quantile), by = "model")
expect_equal(nrow(af), 4)
expect_equal(af$`count`, c(256, 256, 128, 247))

# Setting `collapse = c()` means that all quantiles and samples are counted
af <- get_forecast_counts(
example_quantile,
na.omit(example_quantile),
by = "model", collapse = c()
)
expect_equal(nrow(af), 4)
expect_equal(af$`count`, c(5888, 5888, 2944, 5681))

# setting by = NULL, the default, results in by equal to forecast unit
af <- get_forecast_counts(example_quantile)
af <- get_forecast_counts(na.omit(example_quantile))
expect_equal(nrow(af), 50688)

# check whether collapsing also works for model-based forecasts
af <- get_forecast_counts(example_integer, by = "model")
af <- get_forecast_counts(na.omit(example_integer), by = "model")
expect_equal(nrow(af), 4)

af <- get_forecast_counts(example_integer, by = "model", collapse = c())
af <- get_forecast_counts(na.omit(example_integer), by = "model", collapse = c())
expect_equal(af$count, c(10240, 10240, 5120, 9880))
})
2 changes: 1 addition & 1 deletion tests/testthat/test-check_forecasts.R
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ test_that("as_forecast() function throws an error when no predictions or observe
test_that("output of check_forecasts() is accepted as input to score()", {
check <- suppressMessages(as_forecast(example_binary))
expect_no_error(
score_check <- score(check)
score_check <- score(na.omit(check))
)
expect_equal(score_check, suppressMessages(score(example_binary)))
})
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test-convenience-functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ test_that("function set_forecast_unit() works", {
example_quantile,
c("location", "target_end_date", "target_type", "horizon", "model")
)
scores2 <- score(ex2)
scores2 <- score(na.omit(ex2))
scores2 <- scores2[order(location, target_end_date, target_type, horizon, model), ]

expect_equal(scores1$interval_score, scores2$interval_score)
Expand Down
35 changes: 32 additions & 3 deletions tests/testthat/test-get_-functions.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# ==============================================================================
# `get_forecast_unit()`
# ==============================================================================

test_that("get_forecast_unit() works as expected", {
expect_equal(
get_forecast_unit(example_quantile),
Expand All @@ -15,7 +14,7 @@ test_that("get_forecast_unit() works as expected", {
"forecast_date", "model", "horizon")
)

data <- as_forecast(example_quantile)
data <- as_forecast(na.omit(example_quantile))
ex <- data[, location := NULL]
expect_warning(
get_forecast_unit(ex, check_conflict = TRUE),
Expand All @@ -28,6 +27,36 @@ fixed = TRUE
})


# ==============================================================================
# Test removing `NA` values from the data
# ==============================================================================
test_that("removing NA rows from data works as expected", {
  # number of rows left in the example data once incomplete rows are dropped
  expect_equal(nrow(na.omit(example_quantile)), 20401)

  # a single NA in `observed` removes exactly one of the four rows
  toy <- data.frame(observed = c(NA, 1:3), predicted = 1:4)
  expect_equal(nrow(na.omit(toy)), 3)

  # an additional NA in another row of `predicted` removes a second row
  toy$predicted <- c(1:3, NA)
  expect_equal(nrow(na.omit(toy)), 2)

  # test that attributes and classes are retained
  forecast <- as_forecast(na.omit(example_integer))
  forecast_clean <- na.omit(forecast)
  expect_equal(
    class(forecast_clean),
    c("forecast_sample", "data.table", "data.frame")
  )

  expected_attributes <- get_scoringutils_attributes(forecast)
  expect_equal(
    get_scoringutils_attributes(forecast_clean),
    expected_attributes
  )
})


# ==============================================================================
# `get_type()`
# ==============================================================================
test_that("get_type() works as expected with vectors", {
expect_equal(get_type(1:3), "integer")
expect_equal(get_type(factor(1:2)), "classification")
Expand Down Expand Up @@ -164,7 +193,7 @@ test_that("get_forecast_type() works as expected", {
fixed = TRUE
)

data <- as_forecast(example_integer)
data <- as_forecast(na.omit(example_integer))
attr(data, "forecast_type") <- "binary"
expect_warning(
get_forecast_type(data),
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test-plot_avail_forecasts.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
test_that("plot.forecast_counts() works as expected", {
available_forecasts <- get_forecast_counts(
example_quantile,
na.omit(example_quantile),
by = c("model", "target_type", "target_end_date")
)
p <- plot_forecast_counts(available_forecasts,
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test-plot_interval_coverage.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
test_that("plot_interval_coverage() works as expected", {
coverage <- add_coverage(example_quantile) %>%
coverage <- add_coverage(na.omit(example_quantile)) %>%
summarise_scores(by = c("model", "range"))
p <- plot_interval_coverage(coverage)
expect_s3_class(p, "ggplot")
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test-plot_quantile_coverage.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
test_that("plot_quantile_coverage() works as expected", {
coverage <- add_coverage(example_quantile) %>%
coverage <- add_coverage(na.omit(example_quantile)) %>%
summarise_scores(by = c("model", "quantile"))

p <- plot_quantile_coverage(coverage)
Expand Down
1 change: 1 addition & 0 deletions tests/testthat/test-plot_ranges.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Metric list passed to `score()` below: `metrics_no_cov_no_ae` with its
# "bias" entry removed (`modifyList()` drops elements that are set to NULL).
m <- modifyList(metrics_no_cov_no_ae, list("bias" = NULL))

# Build the summarised scores from a copy of the example data: drop rows that
# contain NA values, add an `interval_range` column computed from `quantile`,
# score with the metric list `m`, and summarise by model, target type and
# interval range.
# NOTE(review): `copy()` presumably keeps `:=` from modifying
# `example_quantile` by reference -- confirm.
sum_scores <- copy(example_quantile) %>%
na.omit() %>%
.[, interval_range := scoringutils:::get_range_from_quantile(quantile)] %>%
score(metrics = m) %>%
summarise_scores(by = c("model", "target_type", "interval_range"))
Expand Down
Loading
Loading