Elr/rel scores #69
@@ -100,3 +100,29 @@ error_if_invalid_output_type <- function(output_type) {
    )
  }
}
#' Validate relative metrics
#'
#' @noRd
validate_relative_metrics <- function(relative_metrics, metrics, by) {
  if (any(is_interval_coverage_metric(relative_metrics))) {
    cli::cli_abort(
      "Interval coverage metrics are not supported for relative skill scores."
    )
  }

  if (length(relative_metrics) > 0 && !"model_id" %in% by) {
    cli::cli_abort(
      "Relative metrics require 'model_id' to be included in {.arg by}."
    )
  }

  extra_metrics <- setdiff(relative_metrics, metrics)
  if (length(extra_metrics) > 0) {
    cli::cli_abort(c(
      "Relative metrics must be a subset of the metrics.",
      "x" = "The following {.arg relative_metrics} are not in {.arg metrics}: {.val {extra_metrics}}"
    ))
  }
}
Review thread on the `model_id` check:

Reviewer: Is this strictly necessary? If we know that "model_id" always needs to be included in `by`, could we just add it to `by` ourselves rather than erroring?

Reviewer: I haven't manually checked out the PR and run the code, so this is just from a cursory reading on GitHub. Let me know if you'd like me to dig deeper; happy to make a suggestion.

Author: I'm good with a cursory review on the level of "this is reasonable or not", given that Zhian is also reviewing. Regarding your comment above: I don't think this is strictly necessary, but my general preference is to throw errors guiding users towards what we're expecting rather than modify their inputs. I think all of this is clear to you already, but just to say it: I think the alternative to enforcing this requirement is that we add another argument along the lines of `compare`, as is used in `scoringutils::get_pairwise_comparisons()`, setting a default of `"model_id"`. I think that would be fine, but since essentially all use cases of this function will include `"model_id"` in `by`, I don't think it's necessarily worth introducing the extra argument here.

Reviewer: I think the current approach is fine, since it would be caught by the validation. The problem with extra arguments that affect other arguments is that it becomes difficult for users to remember the relationships between them.
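For concreteness, a sketch of what the floated alternative might look like from the caller's side. The `compare` argument here is hypothetical, borrowed from `scoringutils::get_pairwise_comparisons()`, and is not part of this PR:

```r
# Hypothetical API sketch only: `compare` is NOT an argument of
# score_model_out() in this PR; the model column is instead required
# to appear in `by`.
score_model_out(
  model_out_tbl = forecast_outputs,
  oracle_output = forecast_oracle_output,
  metrics = c("ae_median", "wis"),
  relative_metrics = "wis",
  compare = "model_id", # hypothetical argument naming the comparison column
  by = c("model_id", "location")
)
```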
tests/testthat/test-score_model_out_rel_metrics.R
@@ -0,0 +1,81 @@
test_that("score_model_out succeeds with valid inputs: quantile output_type, relative wis and ae, no baseline", { | ||
# Forecast data from hubExamples: <https://hubverse-org.github.io/hubExamples/reference/forecast_data.html> | ||
forecast_outputs <- hubExamples::forecast_outputs | ||
forecast_oracle_output <- hubExamples::forecast_oracle_output | ||
|
||
act_scores <- score_model_out( | ||
model_out_tbl = forecast_outputs |> dplyr::filter(.data[["output_type"]] == "quantile"), | ||
oracle_output = forecast_oracle_output, | ||
metrics = c("ae_median", "wis", "interval_coverage_80", "interval_coverage_90"), | ||
relative_metrics = c("ae_median", "wis"), | ||
by = c("model_id", "location") | ||
) | ||
zkamvar marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
exp_scores <- read.csv(test_path("testdata", "exp_pairwise_scores.csv")) |> | ||
dplyr::mutate(location = as.character(location)) |> | ||
dplyr::select(-ae_median_scaled_relative_skill, -wis_scaled_relative_skill) | ||
|
||
expect_equal(act_scores, exp_scores, ignore_attr = TRUE) | ||
}) | ||
test_that("score_model_out succeeds with valid inputs: quantile output_type, relative wis and ae, Flusight-baseline", {
  # Forecast data from hubExamples: <https://hubverse-org.github.io/hubExamples/reference/forecast_data.html>
  forecast_outputs <- hubExamples::forecast_outputs
  forecast_oracle_output <- hubExamples::forecast_oracle_output

  act_scores <- score_model_out(
    model_out_tbl = forecast_outputs |> dplyr::filter(.data[["output_type"]] == "quantile"),
    oracle_output = forecast_oracle_output,
    metrics = c("ae_median", "wis", "interval_coverage_80", "interval_coverage_90"),
    relative_metrics = c("ae_median", "wis"),
    baseline = "Flusight-baseline",
    by = c("model_id", "location")
  )

  exp_scores <- read.csv(test_path("testdata", "exp_pairwise_scores.csv")) |>
    dplyr::mutate(location = as.character(location))

  expect_equal(act_scores, exp_scores, ignore_attr = TRUE)
})
test_that("score_model_out errors when invalid relative metrics are requested", {
  # Forecast data from hubExamples: <https://hubverse-org.github.io/hubExamples/reference/forecast_data.html>
  forecast_outputs <- hubExamples::forecast_outputs
  forecast_oracle_output <- hubExamples::forecast_oracle_output

  # not allowed to compute relative skill for interval coverage
  expect_error(
    score_model_out(
      model_out_tbl = forecast_outputs |> dplyr::filter(.data[["output_type"]] == "quantile"),
      oracle_output = forecast_oracle_output,
      metrics = c("wis", "interval_coverage_80", "interval_coverage_90"),
      relative_metrics = c("interval_coverage_90", "wis")
    ),
    regexp = "Interval coverage metrics are not supported for relative skill scores."
  )

  # relative_metrics must be a subset of metrics
  expect_error(
    score_model_out(
      model_out_tbl = forecast_outputs |> dplyr::filter(.data[["output_type"]] == "quantile"),
      oracle_output = forecast_oracle_output,
      metrics = c("wis", "interval_coverage_80", "interval_coverage_90"),
      relative_metrics = c("ae_median", "wis")
    ),
    regexp = "Relative metrics must be a subset of the metrics."
  )

  # can't ask for relative metrics without breaking down by model_id
  expect_error(
    score_model_out(
      model_out_tbl = forecast_outputs |> dplyr::filter(.data[["output_type"]] == "quantile"),
      oracle_output = forecast_oracle_output,
      metrics = c("wis", "interval_coverage_80", "interval_coverage_90"),
      relative_metrics = "wis",
      by = "location"
    ),
    regexp = "Relative metrics require 'model_id' to be included in `by`."
  )
})
tests/testthat/testdata/exp_pairwise_scores.csv
@@ -0,0 +1,7 @@
"model_id","location","ae_median","wis","interval_coverage_80","interval_coverage_90","ae_median_relative_skill","ae_median_scaled_relative_skill","wis_relative_skill","wis_scaled_relative_skill" | ||
"Flusight-baseline","25",210.75,179.944642857143,0,0.25,1.09843550132001,1,1.12612125846186,1 | ||
"Flusight-baseline","48",593,478.964285714286,0,0,1.12969637511783,1,1.15711681669639,1 | ||
"MOBS-GLEAM_FLUH","25",196.125,162.6875,0.5,0.5,1.02220955016079,0.930604982206406,1.01812340354839,0.904097490299596 | ||
"MOBS-GLEAM_FLUH","48",636.625,467.791071428571,0.5,0.625,1.2128043082789,1.07356661045531,1.13012375159286,0.976672134814704 | ||
"PSI-DICE","25",170.875,139.369642857143,0.625,0.625,0.890605771236332,0.81079478054567,0.872196666228429,0.774513987436612 | ||
"PSI-DICE","48",383.125,316.535714285714,0.375,0.375,0.729873395812849,0.646079258010118,0.764710039995535,0.660875400790396 |
@@ -0,0 +1,127 @@
#' Geometric mean:
#' (x_1 \times x_2 \times \ldots \times x_n)^{1/n}
#'   = \exp\left[\frac{1}{n} \sum_{i=1}^{n} \log(x_i)\right]
geometric_mean <- function(x) {
  exp(mean(log(x)))
}
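Working in log space keeps the product from overflowing when `x` is long or its values are large. A quick sanity check of the identity:

```r
geometric_mean(c(2, 8))  # 4, since sqrt(2 * 8) = 4
prod(c(2, 8))^(1 / 2)    # 4, via the direct product definition
```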
#' Helper function that manually computes pairwise relative skill scores by
#' location. Called from tests in test-score_model_out_rel_metrics.R
get_pairwise_scores_by_loc <- function(scores_per_task, metric, baseline) {
  mean_scores_by_loc <- scores_per_task |>
    dplyr::group_by(dplyr::across(dplyr::all_of(c("model_id", "location")))) |>
    dplyr::summarize(
      mean_score = mean(.data[[metric]], na.rm = TRUE), # nolint: object_usage_linter
      .groups = "drop"
    )

  pairwise_score_ratios <- expand.grid(
    model_id = unique(mean_scores_by_loc$model_id),
    model_id_compare = unique(mean_scores_by_loc$model_id),
    location = unique(mean_scores_by_loc[["location"]])
  ) |>
    dplyr::left_join(mean_scores_by_loc, by = c("model_id" = "model_id", "location")) |>
    dplyr::left_join(mean_scores_by_loc, by = c("model_id_compare" = "model_id", "location")) |>
    dplyr::mutate(
      pairwise_score_ratio = .data[["mean_score.x"]] / .data[["mean_score.y"]] # nolint: object_usage_linter
    )

  result <- pairwise_score_ratios |>
    dplyr::group_by(dplyr::across(dplyr::all_of(c("model_id", "location")))) |>
    dplyr::summarize(
      relative_skill = geometric_mean(.data[["pairwise_score_ratio"]]), # nolint: object_usage_linter
      .groups = "drop"
    ) |>
    dplyr::group_by(dplyr::across(dplyr::all_of("location"))) |>
    dplyr::mutate(
      scaled_relative_skill = .data[["relative_skill"]] /
        .data[["relative_skill"]][.data[["model_id"]] == baseline]
    )

  colnames(result) <- c(
    "model_id", "location",
    paste0(metric, "_relative_skill"),
    paste0(metric, "_scaled_relative_skill")
  )

  return(result)
}
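To see the pairwise logic on a toy case (model names and scores invented for illustration): within a location, each model's relative skill is the geometric mean of the ratios of its mean score to every model's mean score, and the scaled value divides by the baseline's relative skill.

```r
toy_scores <- data.frame(
  model_id = c("A", "B"),
  location = "25",
  wis = c(2, 8)
)
# Model A: sqrt((2 / 2) * (2 / 8)) = 0.5; model B: sqrt((8 / 2) * (8 / 8)) = 2.
# Scaling by baseline "A" gives 1 and 4.
get_pairwise_scores_by_loc(toy_scores, metric = "wis", baseline = "A")
```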
# Forecast data from hubExamples: <https://hubverse-org.github.io/hubExamples/reference/forecast_data.html>
forecast_outputs <- hubExamples::forecast_outputs
forecast_oracle_output <- hubExamples::forecast_oracle_output

# expected scores
exp_scores_unsummarized <- forecast_outputs |>
  dplyr::filter(.data[["output_type"]] == "quantile") |>
  dplyr::left_join(
    forecast_oracle_output |>
      dplyr::filter(.data[["output_type"]] == "quantile") |>
      dplyr::select(-dplyr::all_of(c("output_type", "output_type_id"))),
    by = c("location", "target_end_date", "target")
  ) |>
  dplyr::mutate(
    output_type_id = as.numeric(.data[["output_type_id"]]),
    qs = ifelse(
      .data[["oracle_value"]] >= .data[["value"]],
      .data[["output_type_id"]] * (.data[["oracle_value"]] - .data[["value"]]),
      (1 - .data[["output_type_id"]]) * (.data[["value"]] - .data[["oracle_value"]])
    ),
    q_coverage_80_lower = ifelse(
      .data[["output_type_id"]] == 0.1,
      .data[["oracle_value"]] >= .data[["value"]],
      NA_real_
    ),
    q_coverage_80_upper = ifelse(
      .data[["output_type_id"]] == 0.9,
      .data[["oracle_value"]] <= .data[["value"]],
      NA_real_
    ),
    q_coverage_90_lower = ifelse(
      .data[["output_type_id"]] == 0.05,
      .data[["oracle_value"]] >= .data[["value"]],
      NA_real_
    ),
    q_coverage_90_upper = ifelse(
      .data[["output_type_id"]] == 0.95,
      .data[["oracle_value"]] <= .data[["value"]],
      NA_real_
    )
  ) |>
  dplyr::group_by(dplyr::across(dplyr::all_of(
    c("model_id", "location", "reference_date", "horizon", "target_end_date", "target")
  ))) |>
  dplyr::summarize(
    ae_median = sum(ifelse(
      .data[["output_type_id"]] == 0.5,
      abs(.data[["oracle_value"]] - .data[["value"]]),
      0
    )),
    wis = 2 * mean(.data[["qs"]]),
    interval_coverage_80 = (sum(.data[["q_coverage_80_lower"]], na.rm = TRUE) == 1) *
      (sum(.data[["q_coverage_80_upper"]], na.rm = TRUE) == 1),
    interval_coverage_90 = (sum(.data[["q_coverage_90_lower"]], na.rm = TRUE) == 1) *
      (sum(.data[["q_coverage_90_upper"]], na.rm = TRUE) == 1)
  )
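# Note: `qs` above is the pinball (quantile) loss at level tau = output_type_id,
#   qs = tau * (oracle_value - value)        if oracle_value >= value,
#   qs = (1 - tau) * (value - oracle_value)  otherwise,
# so `wis = 2 * mean(qs)` is the standard quantile-based approximation of the
# weighted interval score, averaging 2 * (pinball loss) over the quantile levels.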
exp_scores_standard <- exp_scores_unsummarized |>
  dplyr::group_by(dplyr::across(dplyr::all_of(
    c("model_id", "location")
  ))) |>
  dplyr::summarize(
    ae_median = mean(.data[["ae_median"]]),
    wis = mean(.data[["wis"]]),
    interval_coverage_80 = mean(.data[["interval_coverage_80"]], na.rm = TRUE),
    interval_coverage_90 = mean(.data[["interval_coverage_90"]], na.rm = TRUE),
    .groups = "drop"
  )

# add pairwise relative scores for ae_median and wis
exp_scores_relative_ae_median <- get_pairwise_scores_by_loc(exp_scores_unsummarized, "ae_median", "Flusight-baseline")
exp_scores_relative_wis <- get_pairwise_scores_by_loc(exp_scores_unsummarized, "wis", "Flusight-baseline")
exp_scores <- exp_scores_standard |>
  dplyr::full_join(exp_scores_relative_ae_median, by = c("model_id", "location")) |>
  dplyr::full_join(exp_scores_relative_wis, by = c("model_id", "location"))

# save
write.csv(exp_scores, testthat::test_path("testdata", "exp_pairwise_scores.csv"), row.names = FALSE)