Skip to content

Commit

Permalink
delete plot_predictions() and make_NA
Browse files Browse the repository at this point in the history
  • Loading branch information
nikosbosse committed Feb 29, 2024
1 parent f02bcd2 commit a2bb455
Show file tree
Hide file tree
Showing 4 changed files with 0 additions and 353 deletions.
7 changes: 0 additions & 7 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,6 @@ export(log_shift)
export(logs_binary)
export(logs_sample)
export(mad_sample)
export(make_NA)
export(make_na)
export(merge_pred_and_obs)
export(new_forecast)
export(overprediction)
Expand All @@ -61,7 +59,6 @@ export(plot_heatmap)
export(plot_interval_coverage)
export(plot_pairwise_comparison)
export(plot_pit)
export(plot_predictions)
export(plot_quantile_coverage)
export(plot_score_table)
export(plot_wis)
Expand Down Expand Up @@ -129,7 +126,6 @@ importFrom(data.table,setDT)
importFrom(data.table,setattr)
importFrom(data.table,setcolorder)
importFrom(data.table,setnames)
importFrom(ggdist,geom_lineribbon)
importFrom(ggplot2,.data)
importFrom(ggplot2,`%+replace%`)
importFrom(ggplot2,aes)
Expand All @@ -143,9 +139,7 @@ importFrom(ggplot2,facet_grid)
importFrom(ggplot2,facet_wrap)
importFrom(ggplot2,geom_col)
importFrom(ggplot2,geom_histogram)
importFrom(ggplot2,geom_line)
importFrom(ggplot2,geom_linerange)
importFrom(ggplot2,geom_point)
importFrom(ggplot2,geom_polygon)
importFrom(ggplot2,geom_text)
importFrom(ggplot2,geom_tile)
Expand All @@ -168,7 +162,6 @@ importFrom(ggplot2,xlab)
importFrom(ggplot2,ylab)
importFrom(lifecycle,deprecated)
importFrom(methods,hasArg)
importFrom(rlang,enexprs)
importFrom(rlang,warn)
importFrom(scoringRules,crps_sample)
importFrom(scoringRules,dss_sample)
Expand Down
236 changes: 0 additions & 236 deletions R/plot.R
Original file line number Diff line number Diff line change
Expand Up @@ -266,242 +266,6 @@ plot_heatmap <- function(scores,
return(plot)
}

#' @title Plot Predictions vs True Values
#'
#' @description
#' Draws observed values together with forecasts. Forecasts are displayed as
#' ribbons for the requested central prediction intervals and, if requested,
#' as a line for the median.
#'
#' @param data a data.frame that follows the same specifications outlined in
#'   [score()]. Rows with `NA` in the column `observed` are not drawn as
#'   observations and rows with `NA` in the column `predicted` are not drawn
#'   as forecasts. To customise your plotting, you can blank out values
#'   selectively using the function [make_NA()].
#' @param by character vector with column names that denote categories by which
#'   the plot should be stratified. If for example you want to have a facetted
#'   plot, this should be a character vector with the columns used in
#'   facetting (note that the facetting still needs to be done outside of the
#'   function call). Observations are reduced to the columns `by`, `x` and
#'   `observed` and then de-duplicated.
#' @param x character vector of length one that denotes the name of the
#'   variable shown on the x-axis.
#' @param interval_range numeric vector indicating the interval ranges to plot.
#'   If 0 is included in `interval_range`, the median prediction will be shown.
#' @return ggplot object with a plot of true vs predicted values
#' @importFrom ggplot2 ggplot scale_colour_manual scale_fill_manual theme_light
#' @importFrom ggplot2 facet_wrap facet_grid aes geom_line .data geom_point
#' @importFrom data.table dcast
#' @importFrom ggdist geom_lineribbon
#' @export
#' @examples
#' library(ggplot2)
#' library(magrittr)
#'
#' example_continuous %>%
#'   make_NA(
#'     what = "truth",
#'     target_end_date >= "2021-07-22",
#'     target_end_date < "2021-05-01"
#'   ) %>%
#'   make_NA(
#'     what = "forecast",
#'     model != "EuroCOVIDhub-ensemble",
#'     forecast_date != "2021-06-07"
#'   ) %>%
#'   plot_predictions(
#'     x = "target_end_date",
#'     by = c("target_type", "location"),
#'     interval_range = c(0, 50, 90, 95)
#'   ) +
#'   facet_wrap(~ location + target_type, scales = "free_y") +
#'   aes(fill = model, color = model)
#'
#' example_continuous %>%
#'   make_NA(
#'     what = "truth",
#'     target_end_date >= "2021-07-22",
#'     target_end_date < "2021-05-01"
#'   ) %>%
#'   make_NA(
#'     what = "forecast",
#'     forecast_date != "2021-06-07"
#'   ) %>%
#'   plot_predictions(
#'     x = "target_end_date",
#'     by = c("target_type", "location"),
#'     interval_range = 0
#'   ) +
#'   facet_wrap(~ location + target_type, scales = "free_y") +
#'   aes(fill = model, color = model)

plot_predictions <- function(data,
                             by = NULL,
                             x = "date",
                             interval_range = c(0, 50, 90)) {

  # Observations and forecasts are separated up front so that missing values
  # can be dropped independently for each of them.
  truth_data <- data.table::as.data.table(data)[!is.na(observed)]
  forecasts <- data.table::as.data.table(data)[!is.na(predicted)]

  # Only the stratification columns, the x variable and the observed value are
  # needed to draw the observations; everything else is dropped before
  # de-duplicating.
  truth_cols <- colnames(truth_data)
  drop_cols <- truth_cols[!(truth_cols %in% c(by, "observed", x))]
  truth_data <- suppressWarnings(truth_data[, (drop_cols) := NULL])
  truth_data <- unique(truth_data)

  # Bring quantile-based and sample-based forecasts into the interval format.
  # Any other input is used as it is.
  if (test_forecast_type_is_quantile(data)) {
    forecasts <- quantile_to_interval(forecasts, keep_quantile_col = FALSE)
  } else if (test_forecast_type_is_sample(data)) {
    forecasts <- sample_to_interval_long(
      forecasts,
      interval_range = interval_range,
      keep_quantile_col = FALSE
    )
  }

  # Rows belonging to the requested non-zero interval ranges form the ribbons.
  ribbon_ranges <- setdiff(interval_range, 0)
  ribbons <- forecasts[forecasts$interval_range %in% ribbon_ranges, ]

  # A left-over quantile column would get in the way of pivoting to wide
  # format below, so it is removed if present.
  if ("quantile_level" %in% names(ribbons)) {
    ribbons[, quantile_level := NULL]
  }

  p <- ggplot(data = data, aes(x = .data[[x]])) +
    theme_scoringutils() +
    ylab("True and predicted values")

  # Ribbons are only added if there is at least one interval to show.
  if (nrow(ribbons) > 0) {
    # one row per interval with the columns `lower` and `upper`
    ribbons <- data.table::dcast(
      ribbons, ... ~ boundary,
      value.var = "predicted"
    )

    p <- p +
      ggdist::geom_lineribbon(
        data = ribbons,
        aes(
          ymin = lower, ymax = upper,
          # fill_ramp is used instead of fill so that fill stays available
          # for other variables (e.g. the model)
          fill_ramp = factor(
            interval_range,
            levels = sort(unique(interval_range), decreasing = TRUE)
          )
        ),
        lwd = 0.4
      ) +
      ggdist::scale_fill_ramp_discrete(
        name = "interval_range",
        # the range keeps the opacity of the ribbon different from the one of
        # the median line, which would otherwise be invisible
        range = c(0.15, 0.75)
      )
  }

  # The median is drawn separately from ggdist::geom_lineribbon() to also
  # cover the case in which only the median is requested (where
  # ggdist::geom_lineribbon() would fail).
  if (0 %in% interval_range) {
    is_median <-
      forecasts$interval_range == 0 & forecasts$boundary == "lower"
    median_forecasts <- forecasts[is_median]

    if (nrow(median_forecasts) > 0) {
      p <- p +
        geom_line(
          data = median_forecasts,
          mapping = aes(y = predicted),
          lwd = 0.4
        )
    }
  }

  # Observed values are shown as points connected by a thin line.
  if (nrow(truth_data) > 0) {
    observed_aes <- aes(x = .data[[x]], y = observed)

    p <- p +
      geom_point(
        data = truth_data,
        show.legend = FALSE,
        inherit.aes = FALSE,
        observed_aes,
        color = "black",
        size = 0.5
      ) +
      geom_line(
        data = truth_data,
        inherit.aes = FALSE,
        show.legend = FALSE,
        observed_aes,
        linetype = 1,
        color = "grey40",
        lwd = 0.2
      )
  }

  p
}

#' @title Make Rows NA in Data for Plotting
#'
#' @description
#' Filters the data and turns values into `NA` before the data gets passed to
#' [plot_predictions()]. Doing so makes it possible to 'filter' prediction and
#' truth data separately: any value that is `NA` will be removed in the
#' subsequent call to [plot_predictions()].
#'
#' @inheritParams score
#' @param what character vector that determines which values should be turned
#'   into `NA`. If `what = "truth"`, values in the column 'observed' will be
#'   turned into `NA`. If `what = "forecast"`, values in the column 'predicted'
#'   will be turned into `NA`. If `what = "both"`, values in both columns will
#'   be turned into `NA`.
#' @param ... logical statements used to filter the data. The statements are
#'   applied one after the other, so a value is turned into `NA` as soon as
#'   its row matches any one of them.
#' @return A data.table. The input itself is left unchanged, the function
#'   works on a copy.
#' @importFrom rlang enexprs
#' @keywords plotting
#' @export
#'
#' @examples
#' make_NA(
#'   example_continuous,
#'   what = "truth",
#'   target_end_date >= "2021-07-22",
#'   target_end_date < "2021-05-01"
#' )

make_NA <- function(data = NULL,
                    what = c("truth", "forecast", "both"),
                    ...) {

  assert_not_null(data = data)

  # Work on a private copy so that the by-reference updates below never touch
  # the caller's object.
  data <- data.table::copy(data)
  # `:=` only works on data.tables. Plain data.frames (or tibbles) are
  # converted in place; this is safe because `data` is already our own copy.
  if (!data.table::is.data.table(data)) {
    data.table::setDT(data)
  }
  what <- match.arg(what)

  # capture the filter statements in ... unevaluated
  filters <- enexprs(...)

  # columns that get blanked out, depending on `what`
  target_cols <- NULL
  if (what %in% c("forecast", "both")) {
    target_cols <- c(target_cols, "predicted")
  }
  if (what %in% c("truth", "both")) {
    target_cols <- c(target_cols, "observed")
  }

  # Each statement is applied on its own. This is deliberately a plain loop:
  # the loop variable has to live in this function's frame so that
  # `eval(filter_expr)` can be resolved when used as the row filter.
  for (filter_expr in filters) {
    data <- data[eval(filter_expr), (target_cols) := NA_real_]
  }

  # `[]` makes sure the result prints after the by-reference updates
  data[]
}

# Lower-case alias: binds the name `make_na` to the same function object as
# `make_NA`, so both spellings can be used interchangeably. It is exported
# and documented on the help page of `make_NA`.
#' @rdname make_NA
#' @keywords plotting
#' @export
make_na <- make_NA

#' @title Plot Interval Coverage
#'
Expand Down
40 changes: 0 additions & 40 deletions man/make_NA.Rd

This file was deleted.

70 changes: 0 additions & 70 deletions man/plot_predictions.Rd

This file was deleted.

0 comments on commit a2bb455

Please sign in to comment.