Skip to content

Commit

Permalink
Create plot_forecast_counts() to replace previous S3 method
Browse files Browse the repository at this point in the history
  • Loading branch information
nikosbosse committed Dec 5, 2023
1 parent c48d6dd commit 18f42a8
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 48 deletions.
2 changes: 1 addition & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# Generated by roxygen2: do not edit by hand

S3method(plot,prediction_counts)
S3method(print,scoringutils_check)
S3method(quantile_to_interval,data.frame)
S3method(quantile_to_interval,numeric)
Expand Down Expand Up @@ -49,6 +48,7 @@ export(pairwise_comparison)
export(pit)
export(pit_sample)
export(plot_correlation)
export(plot_forecast_counts)
export(plot_heatmap)
export(plot_interval_coverage)
export(plot_pairwise_comparison)
Expand Down
2 changes: 0 additions & 2 deletions R/available_forecasts.R
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,5 @@ get_forecast_counts <- function(data,
out <- merge(out, out_empty, by = by, all.y = TRUE)
out[, count := nafill(count, fill = 0)]

class(out) <- c("prediction_counts", class(out))

return(out[])
}
45 changes: 21 additions & 24 deletions R/plot.R
Original file line number Diff line number Diff line change
Expand Up @@ -942,13 +942,13 @@ plot_pit <- function(pit,
#' @description
#' Visualise Where Forecasts Are Available
#' @inheritParams print.scoringutils_check
#' @param x an S3 object of class "prediction_counts"
#' @param forecast_counts a data.table (or similar) with forecast counts
#' as produced by [get_forecast_counts()]
#' @param yvar character vector of length one that denotes the name of the column
#' @param y character vector of length one that denotes the name of the column
#' to appear on the y-axis of the plot. Default is "model".
#' @param xvar character vector of length one that denotes the name of the column
#' to appear on the x-axis of the plot. Default is "forecast_date".
#' @param make_xvar_factor logical (default is TRUE). Whether or not to convert
#' @param x character vector of length one that denotes the name of the column
#' to appear on the x-axis of the plot.
#' @param make_x_factor logical (default is TRUE). Whether or not to convert
#' the variable on the x-axis to a factor. This has an effect e.g. if dates
#' are shown on the x-axis.
#' @param show_numbers logical (default is `TRUE`) that indicates whether
Expand All @@ -960,35 +960,34 @@ plot_pit <- function(pit,
#' @export
#' @examples
#' library(ggplot2)
#' available_forecasts <- get_forecast_counts(
#' forecast_counts <- get_forecast_counts(
#' example_quantile, by = c("model", "target_type", "target_end_date")
#' )
#' plot(
#' available_forecasts, xvar = "target_end_date", show_numbers = FALSE
#' plot_forecast_counts(
#' forecast_counts, x = "target_end_date", show_numbers = FALSE
#' ) +
#' facet_wrap("target_type")

plot.prediction_counts <- function(x,
yvar = "model",
xvar = "forecast_date",
make_xvar_factor = TRUE,
show_numbers = TRUE,
...) {
x <- as.data.table(x)
plot_forecast_counts <- function(forecast_counts,
y = "model",
x,
make_x_factor = TRUE,
show_numbers = TRUE) {

if (make_xvar_factor) {
x[, eval(xvar) := as.factor(get(xvar))]
forecast_counts <- ensure_data.table(forecast_counts)

if (make_x_factor) {
forecast_counts[, eval(x) := as.factor(get(x))]
}

setnames(x, old = "count", new = "Count")
setnames(forecast_counts, old = "count", new = "Count")

plot <- ggplot(
x,
aes(y = .data[[yvar]], x = .data[[xvar]])
forecast_counts,
aes(y = .data[[y]], x = .data[[x]])
) +
geom_tile(aes(fill = `Count`),
width = 0.97, height = 0.97
) +
width = 0.97, height = 0.97) +
scale_fill_gradient(
low = "grey95", high = "steelblue",
na.value = "lightgrey"
Expand All @@ -1001,12 +1000,10 @@ plot.prediction_counts <- function(x,
)
) +
theme(panel.spacing = unit(2, "lines"))

if (show_numbers) {
plot <- plot +
geom_text(aes(label = `Count`))
}

return(plot)
}

Expand Down
33 changes: 15 additions & 18 deletions man/plot.prediction_counts.Rd → man/plot_forecast_counts.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions tests/testthat/test-plot_avail_forecasts.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ test_that("plot.forecast_counts() works as expected", {
example_quantile,
by = c("model", "target_type", "target_end_date")
)
p <- plot(available_forecasts,
xvar = "target_end_date", show_numbers = FALSE
p <- plot_forecast_counts(available_forecasts,
x = "target_end_date", show_numbers = FALSE
) +
facet_wrap("target_type")
expect_s3_class(p, "ggplot")
Expand Down
2 changes: 1 addition & 1 deletion vignettes/scoringutils.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ This information can also be visualised using `plot()`:
```{r, fig.width=11, fig.height=6}
example_quantile %>%
get_forecast_counts(by = c("model", "forecast_date", "target_type")) %>%
plot() +
plot_forecast_counts(x = "forecast_date") +
facet_wrap(~ target_type)
```

Expand Down

0 comments on commit 18f42a8

Please sign in to comment.