Skip to content

Commit

Permalink
Merge branch 'develop' into expose-functions2
Browse files Browse the repository at this point in the history
  • Loading branch information
nikosbosse authored Nov 20, 2023
2 parents bbca6c0 + b778f8a commit 0c26191
Show file tree
Hide file tree
Showing 17 changed files with 321 additions and 109 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/R-CMD-check.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:
- {os: ubuntu-latest, r: 'devel', http-user-agent: 'release'}
- {os: ubuntu-latest, r: 'release'}
- {os: ubuntu-latest, r: 'oldrel-1'}
- {os: ubuntu-latest, r: '3.5'}
- {os: ubuntu-latest, r: '3.6'}

env:
GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }}
Expand Down
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -73,5 +73,5 @@ URL: https://doi.org/10.48550/arXiv.2205.07090, https://epiforecasts.io/scoringu
BugReports: https://github.com/epiforecasts/scoringutils/issues
VignetteBuilder: knitr
Depends:
R (>= 3.5)
R (>= 3.6)
Roxygen: list(markdown = TRUE)
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ export(dispersion)
export(dss_sample)
export(get_duplicate_forecasts)
export(get_forecast_unit)
export(get_forecast_type)
export(interval_coverage_deviation_quantile)
export(interval_coverage_quantile)
export(interval_coverage_sample)
Expand Down Expand Up @@ -86,6 +87,7 @@ importFrom(checkmate,assert_list)
importFrom(checkmate,assert_logical)
importFrom(checkmate,assert_number)
importFrom(checkmate,assert_numeric)
importFrom(checkmate,assert_vector)
importFrom(checkmate,check_atomic_vector)
importFrom(checkmate,check_data_frame)
importFrom(checkmate,check_function)
Expand Down
70 changes: 54 additions & 16 deletions R/check-inputs-scoring-functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -161,25 +161,21 @@ check_input_interval <- function(observed, lower, upper, range) {
#' that `predicted` represents the probability that the observed value is equal
#' to the highest factor level.
#' @param predicted Input to be checked. `predicted` should be a vector of
#' length n, holding probabilities. Values represent the probability that
#' length n, holding probabilities. Alternatively, `predicted` can be a matrix
#' of size n x 1. Values represent the probability that
#' the corresponding value in `observed` will be equal to the highest
#' available factor level.
#' @importFrom checkmate assert assert_factor
#' @inherit document_assert_functions return
#' @keywords check-inputs
assert_input_binary <- function(observed, predicted) {
if (length(observed) != length(predicted)) {
stop("`observed` and `predicted` need to be ",
"of same length when scoring binary forecasts")
}
assert_factor(observed, n.levels = 2)
levels <- levels(observed)
assert(
check_numeric_vector(predicted, min.len = 1, lower = 0, upper = 1)
)
assert_factor(observed, n.levels = 2, min.len = 1)
assert_numeric(predicted, lower = 0, upper = 1)
assert_dims_ok_point(observed, predicted)
return(invisible(NULL))
}


#' @title Check that inputs are correct for binary forecast
#' @inherit assert_input_binary params description
#' @inherit document_check_functions return
Expand All @@ -200,12 +196,9 @@ check_input_binary <- function(observed, predicted) {
#' @inherit document_assert_functions return
#' @keywords check-inputs
assert_input_point <- function(observed, predicted) {
assert(check_numeric_vector(observed, min.len = 1))
assert(check_numeric_vector(predicted, min.len = 1))
if (length(observed) != length(predicted)) {
stop("`observed` and `predicted` need to be ",
"of same length when scoring point forecasts")
}
assert(check_numeric(observed))
assert(check_numeric(predicted))
assert(check_dims_ok_point(observed, predicted))
return(invisible(NULL))
}

Expand All @@ -217,3 +210,48 @@ check_input_point <- function(observed, predicted) {
result <- check_try(assert_input_point(observed, predicted))
return(result)
}


#' @title Assert Inputs Have Matching Dimensions
#' @description Function assesses whether input dimensions match. In the
#' following, n is the number of observations / forecasts. Scalar values may
#' be repeated to match the length of the other input.
#' Allowed options are therefore
#' - `observed` is vector of length 1 or length n
#' - `predicted` is
#' - a vector of of length 1 or length n
#' - a matrix with n rows and 1 column
#' @inherit assert_input_binary
#' @inherit document_assert_functions return
#' @importFrom checkmate assert_vector check_matrix check_vector assert
#' @keywords check-inputs
assert_dims_ok_point <- function(observed, predicted) {
assert_vector(observed, min.len = 1)
n_obs <- length(observed)
assert(
check_vector(predicted, min.len = 1, strict = TRUE),
check_matrix(predicted, ncols = 1, nrows = n_obs)
)
dim_p <- dim(predicted)
if (!is.null(dim_p) && (length(dim_p) > 1) && (dim_p[2] > 1)) {
stop("`predicted` must be a vector or a matrix with one column. Found ",
dim(predicted)[2], " columns")
}
n_pred <- length(as.vector(predicted))
# check that both are either of length 1 or of equal length
if ((n_obs != 1) && (n_pred != 1) && (n_obs != n_pred)) {
stop("`observed` and `predicted` must either be of length 1 or ",
"of equal length. Found ", n_obs, " and ", n_pred)
}
return(invisible(NULL))
}


#' @title Check Inputs Have Matching Dimensions
#' @inherit assert_dims_ok_point params description
#' @inherit document_check_functions return
#' @keywords check-inputs
check_dims_ok_point <- function(observed, predicted) {
result <- check_try(assert_dims_ok_point(observed, predicted))
return(result)
}
2 changes: 0 additions & 2 deletions R/correlations.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@
correlation <- function(scores,
metrics = NULL,
digits = NULL) {
metrics <- check_metrics(metrics)

metrics <- get_metrics(scores)

# if quantile column is present, throw a warning
Expand Down
30 changes: 19 additions & 11 deletions R/get_-functions.R
Original file line number Diff line number Diff line change
@@ -1,19 +1,27 @@
# Functions that help to obtain information about the data

#' @title Infer the type of a forecast based on a data.frame
#' @title Infer Forecast Type
#' @description Helper function to infer the forecast type based on a
#' data.frame or similar with predictions. Please check the vignettes to
#' learn more about forecast types.
#'
#' @description Internal helper function to get the type of the forecast.
#' Options are "sample-based", "quantile-based", "binary" or "point" forecast.
#' The function runs additional checks to make sure the data satisfies
#' requirements and throws an informative error if any issues are found.
#'
#' @inheritParams validate
#' Possible forecast types are
#' - "sample-based"
#' - "quantile-based"
#' - "binary"
#' - "point" forecast.
#'
#' The function runs additional checks to make sure the data satisfies the
#' requirements of the respective forecast type and throws an
#' informative error if any issues are found.
#' @inheritParams score
#' @return Character vector of length one with either "binary", "quantile",
#' "sample" or "point".
#'
#' @keywords internal
#' @export
#' @keywords check-forceasts
get_forecast_type <- function(data) {
assert_data_frame(data)
assert(check_columns_present(data, c("observed", "predicted")))
if (test_forecast_type_is_binary(data)) {
forecast_type <- "binary"
} else if (test_forecast_type_is_quantile(data)) {
Expand All @@ -24,8 +32,8 @@ get_forecast_type <- function(data) {
forecast_type <- "point"
} else {
stop(
"Checking `data`: input doesn't satisfy criteria for any forecast type.",
"Are you missing a column `quantile` or `sample_id`?",
"Checking `data`: input doesn't satisfy criteria for any forecast type. ",
"Are you missing a column `quantile` or `sample_id`? ",
"Please check the vignette for additional info."
)
}
Expand Down
40 changes: 40 additions & 0 deletions man/assert_dims_ok_point.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion man/assert_input_binary.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

40 changes: 40 additions & 0 deletions man/check_dims_ok_point.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion man/check_input_binary.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

23 changes: 17 additions & 6 deletions man/get_forecast_type.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

36 changes: 36 additions & 0 deletions tests/testthat/test-get_-functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -162,3 +162,39 @@ test_that("get_duplicate_forecasts() works as expected for point", {
22
)
})


# ==============================================================================
# `get_forecast_type`
# ==============================================================================
test_that("get_forecast_type() works as expected", {
expect_equal(get_forecast_type(as.data.frame(example_quantile)), "quantile")
expect_equal(get_forecast_type(example_continuous), "sample")
expect_equal(get_forecast_type(example_integer), "sample")
expect_equal(get_forecast_type(example_binary), "binary")
expect_equal(get_forecast_type(example_point), "point")

expect_error(
get_forecast_type(data.frame(x = 1:10)),
"Assertion on 'data' failed: Columns 'observed', 'predicted' not found in data.",
fixed = TRUE
)

df <- data.frame(observed = 1:10, predicted = factor(1:10))
expect_error(
get_forecast_type(df),
"Checking `data`: input doesn't satisfy criteria for any forecast type. Are you missing a column `quantile` or `sample_id`? Please check the vignette for additional info.",
fixed = TRUE
)

data <- validate(example_integer)
attr(data, "forecast_type") <- "binary"
expect_warning(
get_forecast_type(data),
"Object has an attribute `forecast_type`, but it looks different from what's expected based on the data.
Existing: binary
Expected: sample
Running `validate()` again might solve the problem",
fixed = TRUE
)
})
Loading

0 comments on commit 0c26191

Please sign in to comment.