Skip to content

Commit

Permalink
Merge pull request #669 from epiforecasts/remove-plot_predictions
Browse files Browse the repository at this point in the history
#659 Remove function `plot_predictions()`
  • Loading branch information
nikosbosse authored Feb 29, 2024
2 parents 23726b4 + 6fe83a3 commit 1fea902
Show file tree
Hide file tree
Showing 15 changed files with 3 additions and 2,240 deletions.
1 change: 0 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ LazyData: true
Imports:
checkmate,
data.table,
ggdist (>= 3.2.0),
ggplot2 (>= 3.4.0),
lifecycle,
methods,
Expand Down
6 changes: 0 additions & 6 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,6 @@ export(log_shift)
export(logs_binary)
export(logs_sample)
export(mad_sample)
export(make_NA)
export(make_na)
export(merge_pred_and_obs)
export(new_forecast)
export(overprediction)
Expand All @@ -61,7 +59,6 @@ export(plot_heatmap)
export(plot_interval_coverage)
export(plot_pairwise_comparison)
export(plot_pit)
export(plot_predictions)
export(plot_quantile_coverage)
export(plot_score_table)
export(plot_wis)
Expand Down Expand Up @@ -129,7 +126,6 @@ importFrom(data.table,setDT)
importFrom(data.table,setattr)
importFrom(data.table,setcolorder)
importFrom(data.table,setnames)
importFrom(ggdist,geom_lineribbon)
importFrom(ggplot2,.data)
importFrom(ggplot2,`%+replace%`)
importFrom(ggplot2,aes)
Expand All @@ -145,7 +141,6 @@ importFrom(ggplot2,geom_col)
importFrom(ggplot2,geom_histogram)
importFrom(ggplot2,geom_line)
importFrom(ggplot2,geom_linerange)
importFrom(ggplot2,geom_point)
importFrom(ggplot2,geom_polygon)
importFrom(ggplot2,geom_text)
importFrom(ggplot2,geom_tile)
Expand All @@ -168,7 +163,6 @@ importFrom(ggplot2,xlab)
importFrom(ggplot2,ylab)
importFrom(lifecycle,deprecated)
importFrom(methods,hasArg)
importFrom(rlang,enexprs)
importFrom(rlang,warn)
importFrom(scoringRules,crps_sample)
importFrom(scoringRules,dss_sample)
Expand Down
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ The update introduces breaking changes. If you want to keep using the older vers
- Added a method for `print()` that prints out additional information for `forecast` objects.
- Added a subsetting `[` operator for scores, so that the score name attribute gets preserved when subsetting.
- Deleted the function `plot_ranges()`. If you want to continue using the functionality, you can find the function code [here](https://github.com/epiforecasts/scoringutils/issues/462).
- Removed the function `plot_predictions()`, as well as its helper function `make_NA()`, in favour of a dedicated Vignette that shows different ways of visualising predictions. For future reference, the function code can be found [here](https://github.com/epiforecasts/scoringutils/issues/659) (Issue #659).

# scoringutils 1.2.2

Expand Down
240 changes: 2 additions & 238 deletions R/plot.R
Original file line number Diff line number Diff line change
Expand Up @@ -266,242 +266,6 @@ plot_heatmap <- function(scores,
return(plot)
}

#' @title Plot Predictions vs True Values
#'
#' @description
#' Make a plot of observed and predicted values
#'
#' @param data a data.frame that follows the same specifications outlined in
#' [score()]. To customise your plotting, you can filter your data using the
#' function [make_NA()].
#' @param by character vector with column names that denote categories by which
#' the plot should be stratified. If for example you want to have a facetted
#' plot, this should be a character vector with the columns used in facetting
#' (note that the facetting still needs to be done outside of the function call)
#' @param x character vector of length one that denotes the name of the variable
#' @param interval_range numeric vector indicating the interval ranges to plot.
#' If 0 is included in `interval_range`, the median prediction will be shown.
#' @return ggplot object with a plot of true vs predicted values
#' @importFrom ggplot2 ggplot scale_colour_manual scale_fill_manual theme_light
#' @importFrom ggplot2 facet_wrap facet_grid aes geom_line .data geom_point
#' @importFrom data.table dcast
#' @importFrom ggdist geom_lineribbon
#' @export
#' @examples
#' library(ggplot2)
#' library(magrittr)
#'
#' example_continuous %>%
#' make_NA (
#' what = "truth",
#' target_end_date >= "2021-07-22",
#' target_end_date < "2021-05-01"
#' ) %>%
#' make_NA (
#' what = "forecast",
#' model != "EuroCOVIDhub-ensemble",
#' forecast_date != "2021-06-07"
#' ) %>%
#' plot_predictions (
#' x = "target_end_date",
#' by = c("target_type", "location"),
#' interval_range = c(0, 50, 90, 95)
#' ) +
#' facet_wrap(~ location + target_type, scales = "free_y") +
#' aes(fill = model, color = model)
#'
#' example_continuous %>%
#' make_NA (
#' what = "truth",
#' target_end_date >= "2021-07-22",
#' target_end_date < "2021-05-01"
#' ) %>%
#' make_NA (
#' what = "forecast",
#' forecast_date != "2021-06-07"
#' ) %>%
#' plot_predictions (
#' x = "target_end_date",
#' by = c("target_type", "location"),
#' interval_range = 0
#' ) +
#' facet_wrap(~ location + target_type, scales = "free_y") +
#' aes(fill = model, color = model)

plot_predictions <- function(data,
by = NULL,
x = "date",
interval_range = c(0, 50, 90)) {

# split truth data and forecasts in order to apply different filtering
truth_data <- data.table::as.data.table(data)[!is.na(observed)]
forecasts <- data.table::as.data.table(data)[!is.na(predicted)]

del_cols <-
colnames(truth_data)[!(colnames(truth_data) %in% c(by, "observed", x))]

truth_data <- unique(suppressWarnings(truth_data[, eval(del_cols) := NULL]))

# find out what type of predictions we have. convert sample based to
# interval range data

if (test_forecast_type_is_quantile(data)) {
forecasts <- quantile_to_interval(
forecasts,
keep_quantile_col = FALSE
)
} else if (test_forecast_type_is_sample(data)) {
forecasts <- sample_to_interval_long(
forecasts,
interval_range = interval_range,
keep_quantile_col = FALSE
)
}

# select appropriate boundaries and pivot wider
select <- forecasts$interval_range %in% setdiff(interval_range, 0)
intervals <- forecasts[select, ]

# delete quantile column in intervals if present. This is important for
# pivoting
if ("quantile_level" %in% names(intervals)) {
intervals[, quantile_level := NULL]
}

plot <- ggplot(data = data, aes(x = .data[[x]])) +
theme_scoringutils() +
ylab("True and predicted values")

if (nrow(intervals) != 0) {
# pivot wider and convert range to a factor
intervals <- data.table::dcast(intervals, ... ~ boundary,
value.var = "predicted")

# only plot interval ranges if there are interval ranges to plot
plot <- plot +
ggdist::geom_lineribbon(
data = intervals,
aes(
ymin = lower, ymax = upper,
# We use the fill_ramp aesthetic for this instead of the default fill
# because we want to keep fill to be able to use it for other
# variables
fill_ramp = factor(
interval_range,
levels = sort(unique(interval_range), decreasing = TRUE)
)
),
lwd = 0.4
) +
ggdist::scale_fill_ramp_discrete(
name = "interval_range",
# range argument was added to make sure that the line for the median
# and the ribbon don't have the same opacity, making the line
# invisible
range = c(0.15, 0.75)
)
}

# We could treat this step as part of ggdist::geom_lineribbon() but we treat
# it separately here to deal with the case when only the median is provided
# (in which case ggdist::geom_lineribbon() will fail)
if (0 %in% interval_range) {
select_median <-
forecasts$interval_range == 0 & forecasts$boundary == "lower"
median <- forecasts[select_median]

if (nrow(median) > 0) {
plot <- plot +
geom_line(
data = median,
mapping = aes(y = predicted),
lwd = 0.4
)
}
}

# add observed values
if (nrow(truth_data) > 0) {
plot <- plot +
geom_point(
data = truth_data,
show.legend = FALSE,
inherit.aes = FALSE,
aes(x = .data[[x]], y = observed),
color = "black",
size = 0.5
) +
geom_line(
data = truth_data,
inherit.aes = FALSE,
show.legend = FALSE,
aes(x = .data[[x]], y = observed),
linetype = 1,
color = "grey40",
lwd = 0.2
)
}

return(plot)
}

#' @title Make Rows NA in Data for Plotting
#'
#' @description
#' Filters the data and turns values into `NA` before the data gets passed to
#' [plot_predictions()]. The reason to do this is to this is that it allows to
#' 'filter' prediction and truth data separately. Any value that is NA will then
#' be removed in the subsequent call to [plot_predictions()].
#'
#' @inheritParams score
#' @param what character vector that determines which values should be turned
#' into `NA`. If `what = "truth"`, values in the column 'observed' will be
#' turned into `NA`. If `what = "forecast"`, values in the column 'prediction'
#' will be turned into `NA`. If `what = "both"`, values in both column will be
#' turned into `NA`.
#' @param ... logical statements used to filter the data
#' @return A data.table
#' @importFrom rlang enexprs
#' @keywords plotting
#' @export
#'
#' @examples
#' make_NA (
#' example_continuous,
#' what = "truth",
#' target_end_date >= "2021-07-22",
#' target_end_date < "2021-05-01"
#' )

make_NA <- function(data = NULL,
what = c("truth", "forecast", "both"),
...) {

assert_not_null(data = data)

data <- data.table::copy(data)
what <- match.arg(what)

# turn ... arguments into expressions
args <- enexprs(...)

vars <- NULL
if (what %in% c("forecast", "both")) {
vars <- c(vars, "predicted")
}
if (what %in% c("truth", "both")) {
vars <- c(vars, "observed")
}
for (expr in args) {
data <- data[eval(expr), eval(vars) := NA_real_]
}
return(data[])
}

#' @rdname make_NA
#' @keywords plotting
#' @export
make_na <- make_NA

#' @title Plot Interval Coverage
#'
Expand All @@ -514,7 +278,7 @@ make_na <- make_NA
#' Default is "model".
#' @return ggplot object with a plot of interval coverage
#' @importFrom ggplot2 ggplot scale_colour_manual scale_fill_manual .data
#' facet_wrap facet_grid geom_polygon
#' facet_wrap facet_grid geom_polygon geom_line
#' @importFrom data.table dcast
#' @export
#' @examples
Expand Down Expand Up @@ -575,7 +339,7 @@ plot_interval_coverage <- function(coverage,
#' Default is "model".
#' @return ggplot object with a plot of interval coverage
#' @importFrom ggplot2 ggplot scale_colour_manual scale_fill_manual .data aes
#' scale_y_continuous
#' scale_y_continuous geom_line
#' @importFrom data.table dcast
#' @export
#' @examples
Expand Down
21 changes: 0 additions & 21 deletions README.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -97,27 +97,6 @@ remotes::install_github("epiforecasts/scoringutils", dependencies = TRUE)

In this quick start guide we explore some of the functionality of the `scoringutils` package using quantile forecasts from the [ECDC forecasting hub](https://covid19forecasthub.eu/) as an example. For more detailed documentation please see the package vignettes, and individual function documentation.

### Plotting forecasts

As a first step to evaluating the forecasts we visualise them. For the purposes of this example here we make use of `plot_predictions()` to filter the available forecasts for a single model, and forecast date.

```{r, fig.width = 9, fig.height = 6}
example_quantile %>%
make_NA(what = "truth",
target_end_date >= "2021-07-15",
target_end_date < "2021-05-22"
) %>%
make_NA(what = "forecast",
model != "EuroCOVIDhub-ensemble",
forecast_date != "2021-06-28"
) %>%
plot_predictions(
x = "target_end_date",
by = c("target_type", "location")
) +
facet_wrap(target_type ~ location, ncol = 4, scales = "free")
```

### Scoring forecasts

Forecasts can be easily and quickly scored using the `score()` function. `score()` automatically tries to determine the `forecast_unit`, i.e. the set of columns that uniquely defines a single forecast, by taking all column names of the data into account. However, it is recommended to set the forecast unit manually by specifying the "forecast_unit" argument in `as_forecast()` as this may help to avoid errors. This will drop all columns that are neither part of the forecast unit nor part of the columns internally used by `scoringutils`. The function `as_forecast()` processes and validates the inputs.
Expand Down
25 changes: 0 additions & 25 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -103,31 +103,6 @@ forecasting hub](https://covid19forecasthub.eu/) as an example. For more
detailed documentation please see the package vignettes, and individual
function documentation.

### Plotting forecasts

As a first step to evaluating the forecasts we visualise them. For the
purposes of this example here we make use of `plot_predictions()` to
filter the available forecasts for a single model, and forecast date.

``` r
example_quantile %>%
make_NA(what = "truth",
target_end_date >= "2021-07-15",
target_end_date < "2021-05-22"
) %>%
make_NA(what = "forecast",
model != "EuroCOVIDhub-ensemble",
forecast_date != "2021-06-28"
) %>%
plot_predictions(
x = "target_end_date",
by = c("target_type", "location")
) +
facet_wrap(target_type ~ location, ncol = 4, scales = "free")
```

![](man/figures/unnamed-chunk-4-1.png)<!-- -->

### Scoring forecasts

Forecasts can be easily and quickly scored using the `score()` function.
Expand Down
Loading

0 comments on commit 1fea902

Please sign in to comment.