Skip to content

Commit

Permalink
fixed forecast verification bug and added invariance & lag warning
Browse files Browse the repository at this point in the history
- when the training window/data has no variance, an error is now raised, because the bsts and lmSpike packages do not work in that case
- added an error when the number of lagged dependent variable values exceeds the training window
- fixed error with evaluation plan function and plot slice order
  • Loading branch information
petersen-f committed Nov 29, 2023
1 parent 5c52d1a commit 5b2d412
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 34 deletions.
92 changes: 60 additions & 32 deletions R/predictiveAnalytics.R
Original file line number Diff line number Diff line change
Expand Up @@ -896,9 +896,21 @@ lagit <- function(a,k) {

dataControl <- jaspResults[["predanResults"]][["predanBounds"]]$object[[1]]

# throw error when lags are larger than training window as lags can't be computed
if(options$featEngLags > options$resampleInitialTraining){
errorPlot <- createJaspPlot(dependencies= c("featEngLags","resampleInitialTraining"))
errorPlot$setError(gettext(paste(
"The length of the training window is shorter than the number of lags selected in the 'Feature Engineering' section.",
"This makes it impossible to compute all the values of the lagged dependent variable as there is too little data for training",
"Either increase the training window size or reduce the number of lags."
)))
jaspResults[["predanMainContainer"]][["cvContainer"]][["errorPlot"]] <- errorPlot
return()
}

if(is.null(jaspResults[["predanResults"]][["cvPlanState"]])){
cvPlanState <- createJaspState(dependencies = c(.modelDependencies(),.forecastVeriDependencies()))
cvPlanState$object <- .crossValidationPlanHelper(data = na.omit(dataControl),
cvPlanState$object <- .crossValidationPlanHelper(data = dataControl,
initial = options$resampleInitialTraining,
assess = options$resampleForecastHorizon,
cumulative = options$resampleCumulativeCheck,
Expand All @@ -920,7 +932,9 @@ lagit <- function(a,k) {
cvPlot$plotObject <- .cvPlanPlot(data = dataControl,
cvPlan = jaspResults[["predanResults"]][["cvPlanState"]]$object,
equal_distance = options$resamplePlanPlotEqualDistance,
maxSlices = options$resamplePlanPlotMaxPlots,ncol=1)
maxSlices = options$resamplePlanPlotMaxPlots,
ncol=1,
from = options$"resampleSliceStart")
jaspResults[["predanMainContainer"]][["cvContainer"]][["cvPlanPlot"]] <- cvPlot
}

Expand All @@ -931,7 +945,7 @@ lagit <- function(a,k) {
.forecastVeriDependencies(),
"selectedModels"))

dataEng <- na.omit(jaspResults[["predanResults"]][["featureEngState"]]$object)
dataEng <- jaspResults[["predanResults"]][["featureEngState"]]$object

print(paste('selected models are: ', options$selectedModels))
modelList <- .createModelListHelper(dataEng,unlist(options$selectedModels))
Expand Down Expand Up @@ -978,49 +992,60 @@ lagit <- function(a,k) {
return()
}


.crossValidationPlanHelper <- function(data,initial = 20,assess = 10,cumulative = TRUE,skip = 10,lag = 0, max_slice = 5,from = c("head","tail")){
from <- match.arg(from)
.crossValidationPlanHelper <- function(data,
initial = 5,
assess = 1,
skip = 1,
lag = 0,
cumulative = FALSE,
max_slice = 5,
from = c('head', 'tail')) {
n <- nrow(data)

if(from == "tail"){
stops <- n - seq(assess, (n - initial), by = skip)
if (from == 'head') {
stops <- seq(initial + lag, (n - assess), by = skip)
if (!cumulative) {
starts <- stops - initial + 1
} else {
starts <- rep(1, length(stops))
# only start training frame after lag are created so no NA values
starts <- rep(lag + 1, length(stops))
}


in_ind <- mapply(seq, starts, stops, SIMPLIFY = FALSE)
out_ind <- mapply(seq, stops + 1 , stops + assess, SIMPLIFY = FALSE)
merge_lists <- function(a, b)
list(analysis = a, assessment = b)
indices <-
mapply(merge_lists, in_ind, out_ind, SIMPLIFY = FALSE)
names(indices) <- paste0("slice", 1:length(indices))
indices <- head(indices, max_slice)
} else {
stops <- seq(initial, (n - assess), by = skip + 1)
starts <- if (!cumulative) {
stops - initial + 1
stops <- n - seq(assess, (n - initial - lag), by = skip)
if (!cumulative) {
starts <- stops - initial + 1
} else {
starts <- rep(1, length(stops))
# only start training frame after lag are created so no NA values
starts <- rep(lag + 1, length(stops))
}

in_ind <- mapply(seq, starts, stops, SIMPLIFY = FALSE)
out_ind <- mapply(seq, stops + 1 , stops + assess, SIMPLIFY = FALSE)
merge_lists <- function(a, b) list(analysis = a, assessment = b)
indices <- mapply(merge_lists, in_ind, out_ind, SIMPLIFY = FALSE)
names(indices) <- paste0("slice", 1:length(indices))
indices <- head(indices, max_slice)
#reverse order so first slice includes most recent observation
#that way BMA weights will be based on most recent data and not beginning
indices <- rev(indices)
}
in_ind <- mapply(seq, starts, stops, SIMPLIFY = FALSE)
out_ind <- mapply(seq, stops + 1 - lag, stops + assess, SIMPLIFY = FALSE)
merge_lists <- function(a, b) list(analysis = a, assessment = b)
indices <- mapply(merge_lists, in_ind, out_ind, SIMPLIFY = FALSE)
names(indices) <- paste0("slice",length(indices):1)

if(from == "tail"){
indices <- rev(head(indices,max_slice))
}
else {
names(indices) <- rev(names(indices))
indices <- head(indices,max_slice)
}

return(indices)

}

.cvPlanPlot <- function(data,cvPlan,maxSlices=2,equal_distance = T,...){
.cvPlanPlot <- function(data, cvPlan, maxSlices=2, equal_distance = T, ncol, from){

t_var <- ifelse(equal_distance,"tt","time")

# reverse so most recent/oldest slices are on top depending on resampleSliceStart
if(from == "tail") cvPlan <- rev(cvPlan)
data$X <- 1:nrow(data)
dataPlot <- dplyr::bind_rows(.id = "id",lapply(head(cvPlan,maxSlices), function(x) data.frame( tt = c(x$analysis,x$assessment),
type = rep(c("Analysis","Assessment"),c(length(x$analysis),length(x$assessment)) ))))
Expand All @@ -1037,7 +1062,7 @@ lagit <- function(a,k) {
ggplot2::theme(plot.margin = ggplot2::margin(t = 3, r = 12, b = 0, l = 1)) +
ggplot2::scale_y_continuous(name = "Y",breaks = yBreaks,limits = range(yBreaks)) +
ggplot2::scale_x_continuous(name = "Time",breaks = xBreaks,limits = range(xBreaks)) +
ggplot2::facet_wrap(facets = "id",scales = "free",...) +jaspGraphs::geom_rangeframe() +
ggplot2::facet_wrap(facets = "id",scales = "free",ncol = ncol) +jaspGraphs::geom_rangeframe() +
ggplot2::theme(plot.margin = ggplot2::margin(t = 3, r = 12, b = 0, l = 1)) +
ggplot2::theme(panel.grid = ggplot2::theme_bw()$panel.grid,
panel.background = ggplot2::element_rect(fill = "white"),
Expand Down Expand Up @@ -1097,6 +1122,8 @@ lagit <- function(a,k) {
# wrapper that performs preprocessing, trains model and does prediction
.predAnModelFit <- function(trainData, formula, method, fit,predictFuture = F, testData,model_args = list(),preProList = NULL,...){

if((var(trainData$y) == 0 || is.na(var(trainData$y))))
stop(gettextf('Attempted to fit prediction model %s, but this model requires that the variance of the dependent variable is larger than zero. Either increase the training window or choose a different prediction model.',method))

lags <- sum(grepl("y_lag",labels(terms(formula))))

Expand Down Expand Up @@ -1523,6 +1550,7 @@ lagit <- function(a,k) {
#order slices properly so plot shows correctly
slicesLevels <- unique(dataPlot$slice)
slicesLevels <- slicesLevels[order(nchar(slicesLevels))]
if(options$resampleSliceStart == "tail") slicesLevels <- rev(slicesLevels)
dataPlot$slice <- factor(dataPlot$slice,levels = slicesLevels)

#slicesInclude <- ifelse(options$resampleSliceStart == 'head',head(slicesLevels,maxSlices),tail(slicesLevels,maxSlices))
Expand Down
6 changes: 4 additions & 2 deletions inst/qml/predictiveAnalytics.qml
Original file line number Diff line number Diff line change
Expand Up @@ -333,12 +333,14 @@ Form
//Layout.columnSpan: 1
IntegerField{
name: "resampleForecastHorizon"
min: 2
id: "resampleForecastHorizon"
label: qsTr("Prediction window")
defaultValue: Math.floor((dataSetModel.rowCount() / 5)*0.6)
}
IntegerField{
name: "resampleInitialTraining"
min: 2
id: "resampleInitialTraining"
label: qsTr("Training window")
defaultValue: Math.floor((dataSetModel.rowCount() / 5)*1.4)
Expand Down Expand Up @@ -705,12 +707,12 @@ Form
RadioButton
{
value: "last"
checked: true
checked: !resampleCumulativeCheck.checked
label: qsTr("Last")
childrenOnSameRow: true
IntegerField{name: "futurePredTrainingPoints"; afterLabel: qsTr("data points"); defaultValue: resampleInitialTraining.value}
}
RadioButton{name: "all"; label: qsTr("All data points")}
RadioButton{name: "all"; label: qsTr("All data points"); checked: resampleCumulativeCheck.checked }

}
}
Expand Down

0 comments on commit 5b2d412

Please sign in to comment.