Skip to content

Commit

Permalink
Merge pull request #368 from epiforecasts/rework-summarise-scores
Browse files Browse the repository at this point in the history
Rework `summarise_scores()` and `pairwise_comparison()`
  • Loading branch information
nikosbosse authored Nov 7, 2023
2 parents cb42edb + 59eff74 commit c745f9f
Show file tree
Hide file tree
Showing 37 changed files with 498 additions and 269 deletions.
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ S3method(validate,scoringutils_quantile)
S3method(validate,scoringutils_sample)
export(abs_error)
export(add_coverage)
export(add_pairwise_comparison)
export(ae_median_quantile)
export(ae_median_sample)
export(avail_forecasts)
Expand Down Expand Up @@ -66,6 +67,7 @@ importFrom(Metrics,ae)
importFrom(Metrics,ape)
importFrom(Metrics,se)
importFrom(checkmate,assert)
importFrom(checkmate,assert_character)
importFrom(checkmate,assert_data_frame)
importFrom(checkmate,assert_data_table)
importFrom(checkmate,assert_factor)
Expand Down
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ The update introduces breaking changes. If you want to keep using the older vers
- `predicted`: numeric, a vector (if `observed` is a scalar) or a matrix (if `observed` is a vector)
- `quantile`: numeric, a vector with quantile-levels. Can alternatively be a matrix of the same shape as `predicted`.
- `check_forecasts()` was replaced by a new function `validate()`. `validate()` validates the input and in that sense fulfills the purpose of `check_forecasts()`. It has different methods: `validate.default()` assigns the input a class based on their forecast type. Other methods validate the input specifically for the various forecast types.
- The functionality for computing pairwise comparisons has now been split out of `summarise_scores()`. Instead of doing pairwise comparisons as part of summarising scores, a new function, `add_pairwise_comparison()`, was introduced that takes summarised scores as an input and adds pairwise comparisons to it.
- The function `find_duplicates()` was renamed to `get_duplicate_forecasts()`
- Changes to `avail_forecasts()` and `plot_avail_forecasts()`:
- The function `avail_forecasts()` was renamed to `available_forecasts()` for consistency with `available_metrics()`. The old function, `avail_forecasts()` is still available as an alias, but will be removed in the future.
Expand Down
26 changes: 24 additions & 2 deletions R/check-input-helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -324,13 +324,17 @@ check_duplicates <- function(data, forecast_unit = NULL) {
#' @param columns names of columns to be checked
#' @return Returns a string with a message describing the first issue
#'   encountered if any of the column names are not in data; otherwise
#'   returns TRUE
#'
#' @importFrom checkmate assert_character
#' @keywords check-inputs
check_columns_present <- function(data, columns) {
if (is.null(columns)) {
return(TRUE)
}
assert_character(columns, min.len = 1)
colnames <- colnames(data)
for (x in columns){
if (!(x %in% colnames)) {
msg <- paste0("Data needs to have a column called '", x, "'")
msg <- paste0("Column '", x, "' not found in data")
return(msg)
}
}
Expand Down Expand Up @@ -388,3 +392,21 @@ check_data_columns <- function(data) {
}


#' Check whether an attribute is present
#' @description Checks whether a given attribute is set on an object.
#' @param object An object to be checked
#' @param attribute name of an attribute to be checked
#' @return Returns TRUE if the attribute is present, and an error message as
#'   a string otherwise
#' @keywords check-inputs
check_has_attribute <- function(object, attribute) {
  # guard clause: attribute exists -> success
  if (!is.null(attr(object, attribute))) {
    return(TRUE)
  }
  paste0("Found no attribute `", attribute, "`")
}


59 changes: 56 additions & 3 deletions R/get_-functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -126,14 +126,40 @@ get_target_type <- function(data) {
}


#' @title Get metrics that were used for scoring
#'
#' @description Internal helper function to get the metrics that were used
#' to score forecasts.
#' @param scores A data.table with an attribute `metric_names`
#'
#' @return Character vector with the metrics that were used for scoring.
#'
#' @keywords internal

get_metrics <- function(scores) {
  metric_names <- attr(scores, "metric_names")
  if (is.null(metric_names)) {
    # fail loudly: downstream functions cannot work without knowing which
    # columns hold metric values
    stop("The data needs to have an attribute `metric_names` with the names ",
         "of the metrics that were used for scoring. This should be the case ",
         "if the data was produced using `score()`. Either run `score()` ",
         "again, or set the attribute manually using ",
         "`attr(data, 'metric_names') <- names_of_the_scoring_metrics`")
  }
  return(metric_names)
}





#' @title Get unit of a single forecast
#'
#' @description Helper function to get the unit of a single forecast, i.e.
#' the column names that define where a single forecast was made for.
#' This just takes all columns that are available in the data and subtracts
#' the columns that are protected, i.e. those returned by
#' [get_protected_columns()].
#' [get_protected_columns()] as well as the names of the metrics that were
#' specified during scoring, if any.
#'
#' @inheritParams validate
#'
Expand All @@ -144,7 +170,9 @@ get_target_type <- function(data) {

get_forecast_unit <- function(data) {
  # Columns that can never define the forecast unit: the protected columns
  # plus any metric columns recorded on the data via the `metric_names`
  # attribute (set by `score()`, if present).
  reserved_cols <- c(get_protected_columns(data), attr(data, "metric_names"))
  forecast_unit <- setdiff(colnames(data), unique(reserved_cols))
  return(forecast_unit)
}

Expand All @@ -166,7 +194,8 @@ get_protected_columns <- function(data = NULL) {

protected_columns <- c(
"predicted", "observed", "sample_id", "quantile", "upper", "lower",
"pit_value", "range", "boundary", available_metrics(),
"pit_value", "range", "boundary", "relative_skill", "scaled_rel_skill",
available_metrics(),
grep("coverage_", names(data), fixed = TRUE, value = TRUE)
)

Expand Down Expand Up @@ -215,3 +244,27 @@ get_duplicate_forecasts <- function(data, forecast_unit = NULL) {
out[, scoringutils_InternalDuplicateCheck := NULL]
return(out[])
}


#' @title Get a list of all attributes of a scoringutils object
#'
#' @param object An object of class `scoringutils_`
#'
#' @return A named list with the attributes of that object. Attributes that
#'   are not set on the object are omitted from the list.
#' @keywords internal
get_scoringutils_attributes <- function(object) {
  known_attributes <- c(
    "scoringutils_by",
    "forecast_unit",
    "forecast_type",
    "metric_names",
    "messages",
    "warnings"
  )

  collected <- list()
  for (name in known_attributes) {
    value <- attr(object, name)
    # only keep attributes that are actually present (NULL means absent)
    if (!is.null(value)) {
      collected[[name]] <- value
    }
  }
  return(collected)
}
8 changes: 6 additions & 2 deletions R/pairwise-comparisons.R
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
#' @param metric A character vector of length one with the metric to do the
#' comparison on. The default is "auto", meaning that either "interval_score",
#' "crps", or "brier_score" will be selected where available.
#' See [available_metrics()] for available metrics.
#' @param by character vector with names of columns present in the input
#' data.frame. `by` determines how pairwise comparisons will be computed.
#' You will get a relative skill score for every grouping level determined in
Expand Down Expand Up @@ -67,9 +66,14 @@ pairwise_comparison <- function(scores,
baseline = NULL,
...) {

# metric_names <- get_metrics(scores)
metric <- match.arg(metric, c("auto", available_metrics()))

scores <- data.table::as.data.table(scores)
if (!is.data.table(scores)) {
scores <- as.data.table(scores)
} else {
scores <- copy(scores)
}

# determine metric automatically
if (metric == "auto") {
Expand Down
4 changes: 1 addition & 3 deletions R/plot.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,7 @@ plot_score_table <- function(scores,

# identify metrics -----------------------------------------------------------
id_vars <- get_forecast_unit(scores)
if (is.null(metrics)) {
metrics <- names(scores)[names(scores) %in% available_metrics()]
}
metrics <- get_metrics(scores)

scores <- delete_columns(
scores,
Expand Down
15 changes: 12 additions & 3 deletions R/score.R
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,7 @@
#' [example_integer], [example_point()], and [example_binary]).
#'
#' @param metrics the metrics you want to have in the output. If `NULL` (the
#' default), all available metrics will be computed. For a list of available
#' metrics see [available_metrics()], or check the [metrics] data set.
#'
#' default), all available metrics will be computed.
#' @param ... additional parameters passed down to other functions.
#'
#' @return A data.table with unsummarised scores. There will be one score per
Expand Down Expand Up @@ -129,6 +127,8 @@ score.scoringutils_binary <- function(data, metrics = metrics_binary, ...) {
return()
}, ...)

setattr(data, "metric_names", names(metrics))

return(data[])

}
Expand Down Expand Up @@ -156,6 +156,8 @@ score.scoringutils_point <- function(data, metrics = metrics_point, ...) {
return()
}, ...)

setattr(data, "metric_names", names(metrics))

return(data[])
}

Expand Down Expand Up @@ -187,6 +189,7 @@ score.scoringutils_sample <- function(data, metrics = metrics_sample, ...) {
by = forecast_unit
]

setattr(data, "metric_names", names(metrics))

return(data[])
}
Expand All @@ -206,5 +209,11 @@ score.scoringutils_quantile <- function(data, metrics = NULL, ...) {
...
)

setattr(scores, "metric_names", metrics[metrics %in% colnames(scores)])
# manual hack to make sure that the correct attributes are there.
setattr(scores, "forecast_unit", forecast_unit)
setattr(scores, "forecast_type", "quantile")
scores <- new_scoringutils(scores, "scoringutils_quantile")

return(scores[])
}
Loading

0 comments on commit c745f9f

Please sign in to comment.