Skip to content

Commit

Permalink
no warning no error
Browse files Browse the repository at this point in the history
  • Loading branch information
sima-njf committed Dec 11, 2024
1 parent a2a5749 commit e99b9d9
Show file tree
Hide file tree
Showing 23 changed files with 123 additions and 407 deletions.
3 changes: 2 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,6 @@ Imports:
keras3,
dplyr,
ggplot2 (>= 3.4.0),
stats
stats,
reticulate
VignetteBuilder: knitr
4 changes: 2 additions & 2 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@ importFrom(data.table,as.data.table)
importFrom(data.table,copy)
importFrom(data.table,dcast)
importFrom(data.table,melt)
importFrom(data.table,setnames)
importFrom(dplyr,mutate)
importFrom(dplyr,row_number)
importFrom(epiworldR,ModelSIRCONN)
importFrom(epiworldR,run)
importFrom(epiworldR,verbose_off)
importFrom(ggplot2,aes)
Expand All @@ -29,15 +31,13 @@ importFrom(ggplot2,geom_abline)
importFrom(ggplot2,geom_boxplot)
importFrom(ggplot2,geom_point)
importFrom(ggplot2,labs)
importFrom(keras3,array_reshape)
importFrom(parallel,clusterEvalQ)
importFrom(parallel,clusterExport)
importFrom(parallel,makeCluster)
importFrom(parallel,mclapply)
importFrom(parallel,parLapply)
importFrom(parallel,stopCluster)
importFrom(stats,plogis)
importFrom(stats,predict)
importFrom(stats,qlogis)
importFrom(stats,rbeta)
importFrom(stats,rgamma)
6 changes: 1 addition & 5 deletions R/build_cnn_model.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#' Build a Convolutional Neural Network Model
#'
#' @description
#' Constructs and compiles a CNN model using the `keras` package.
#' Constructs and compiles a CNN model using the `keras3` package.
#'
#' @param input_shape Integer vector. The shape of the input data (excluding batch size).
#' @param output_units Integer. The number of output units (number of output variables).
Expand All @@ -21,9 +21,5 @@ build_cnn_model <- function(input_shape, output_units) {

model |> keras3::compile(optimizer = 'adam', loss = 'mse', metrics = 'accuracy')

# Save the model in native Keras format
# model$save('/home/u1418987/epiworld-benchmark-oct15/calibration/sir-cnn.keras')


return(model)
}
63 changes: 6 additions & 57 deletions R/calibrate_sir.R
Original file line number Diff line number Diff line change
@@ -1,82 +1,31 @@

#' Predict Parameters Using a CNN Model
#' Generates predictions for input test data using a pre-trained Convolutional Neural Network (CNN) model.
#'
#' This function loads a specific Keras CNN model based on the length of the input data and uses it to make predictions. The predicted values are returned as a `data.table` with standardized column names.
#' Calibrate SIR Model Predictions
#'
#' @param data A numeric matrix or array containing the input test data. The function determines which model to load based on the number of columns in `data` (expects either 30 or 60).
#' @description
#' Generates predictions for input test data using a pre-trained CNN model.
#'
#' @return A list containing:
#' \describe{
#' \item{pred}{A `data.table` of predicted values with the following columns:
#' \describe{
#' \item{\code{preval}}{Predicted prevalence.}
#' \item{\code{crate}}{Predicted case rate.}
#' \item{\code{ptran}}{Predicted transmission probability.}
#' \item{\code{prec}}{Predicted precision.}
#' }
#' }
#' }
#'
#' @details
#' The function determines which pre-trained CNN model to load based on the number of features (columns) in the input `data`. If `data` has 30 columns, it loads the `sir30-cnn.keras` model; if it has 60 columns, it loads the `sir60-cnn.keras` model. Ensure that the input data matches one of these expected formats to avoid errors.
#' @param data A numeric matrix or array containing the input test data.
#' @return A list containing a `data.table` of predicted values.
#' @export
# calibrate_sir <- function(data) {
# library(keras3)
# ans=preprocessing_data(data)
# a=length(ans)
# ans <- tensorflow::array_reshape(ans, dim = c(1, 1, a, 1))
#
# if(a <=30){
# model <- keras3::load_model(
# system.file("models", "sir30-cnn.keras", package = "epiworldRcalibrate")
# )
# }
# else{
# model <- keras3::load_model(
# system.file("models", "sir60-cnn.keras", package = "epiworldRcalibrate")
# )
# }
# pred <- predict(model, x =ans ) |>
# data.table::as.data.table() |>
# data.table::setnames(c("preval","crate","ptran","prec"))
# pred$crate=qlogis(pred$crate)
#
# return(list(pred = pred))
# }
calibrate_sir <- function(data) {
# Load required libraries
library(tensorflow)
library(data.table)

# Preprocess the data
ans <- preprocessing_data(data)
a <- length(ans)
ans <- tensorflow::array_reshape(ans, dim = c(1, 1, a, 1)) # Reshape for the model
ans <- tensorflow::tf$reshape(ans, shape = c(1L, 1L, a, 1L))

# Determine model file path
model_path <- if (a <= 31) {
system.file("models", "sir30-cnn.keras", package = "epiworldRcalibrate")
} else {
system.file("models", "sir60-cnn.keras", package = "epiworldRcalibrate")
}

# Check if the model file exists
if (model_path == "") {
stop("Model file not found. Please ensure the models are included in the 'epiworldRcalibrate' package.")
}

# Load the model using tensorflow
model <- tensorflow::tf$keras$models$load_model(model_path)

# Make predictions
pred <- model$predict(ans) |>
data.table::as.data.table() |>
data.table::setnames(c("preval", "crate", "ptran", "prec"))


# Return predictions as a list
return(list(pred = pred))
}


146 changes: 37 additions & 109 deletions R/dataprep.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@
#'
#' @param m An `epiworldR` model object.
#' @param max_days The maximum number of days to consider for data preparation. Defaults to 50.
#'
#' @return A reshaped array of processed data suitable for TensorFlow models. If an error occurs, it returns the error object.
#' @return A reshaped array of processed data suitable for TensorFlow models.
#' If an error occurs, it returns the error object.
#'
#' @export

prepare_data <- function(m, max_days =max_days) {
prepare_data <- function(m, max_days = max_days) {

err <- tryCatch({
ans <- list(
Expand All @@ -16,139 +15,68 @@ prepare_data <- function(m, max_days =max_days) {
gentime = epiworldR::plot_generation_time(m, plot = FALSE)
)

# Filling
# Convert all elements to data.table
ans <- lapply(ans, data.table::as.data.table)

# Replacing NaN and NAs with the previous value
# in each element in the list
# Replace NA values with the last observed value
ans$repnum$avg <- data.table::nafill(ans$repnum$avg, type = "locf")

# Replace NA values with the last observed value
ans$gentime$avg <- data.table::nafill(ans$gentime$avg, type = "locf")
# Replace NA values in repnum$avg and gentime$avg with the last observed value
ans[["repnum"]][, avg := data.table::nafill(avg, type = "locf")]
ans[["gentime"]][, avg := data.table::nafill(avg, type = "locf")]

# Filter up to max_days
ans[["repnum"]] <- ans[["repnum"]][date <= max_days, ]
ans[["gentime"]] <- ans[["gentime"]][date <= max_days, ]

# Filtering up to max_days
ans$repnum <- ans$repnum[ans$repnum$date <= max_days,]
ans$gentime <- ans$gentime[ans$gentime$date <= max_days,]
ans$incidence <- ans$incidence[as.integer(rownames(ans$incidence)) <= (max_days + 1),]
# incidence is indexed by row number since date is not explicitly in that data
# We assume the first row represents day 0, hence (max_days + 1) rows total.
ans[["incidence"]] <- ans[["incidence"]][as.integer(.I) <= (max_days + 1), ]

# Reference table for merging
# ndays <- epiworldR::get_ndays(m)
# Create a reference table for merging
ref_table <- data.table::data.table(date = 0:max_days)

ref_table <- data.table::data.table(
date = 0:max_days
)

# Replace the $ with the [[ ]] to avoid the warning in the next
# two lines
# Merge repnum and gentime with the reference table to ensure consistent length
ans[["repnum"]] <- data.table::merge.data.table(
ref_table, ans[["repnum"]], by = "date", all.x = TRUE
)
ans[["gentime"]] <- data.table::merge.data.table(
ref_table, ans[["gentime"]], by = "date", all.x = TRUE
)


# Generating the data.table with necessary columns
# Create a data.table with all required columns
ans <- data.table::data.table(
infected = ans[["incidence"]][["Infected"]],
recovered = ans[["incidence"]][["Recovered"]],
repnum = ans[["repnum"]][["avg"]],
gentime = ans[["gentime"]][["avg"]],
repnum_sd = ans[["repnum"]][["sd"]],
infected = ans[["incidence"]][["Infected"]],
recovered = ans[["incidence"]][["Recovered"]],
repnum = ans[["repnum"]][["avg"]],
gentime = ans[["gentime"]][["avg"]],
repnum_sd = ans[["repnum"]][["sd"]],
gentime_sd = ans[["gentime"]][["sd"]]
)

# Replace NA values with the last observed value for all columns
# Replace NA values in all relevant columns using locf
nafill_cols <- c("infected", "recovered", "repnum", "gentime", "repnum_sd", "gentime_sd")

for (col in nafill_cols) {
ans[[col]] <- data.table::nafill(ans[[col]], type = "locf")
}

# Return ans as processed data
ans
}, error = function(e) e)

# If there is an error, return NULL
# If there was an error, return it
if (inherits(err, "error")) {
return(err)
}

# Returning without the first observation (which is mostly zero)
dprep <- t(diff(as.matrix(ans[-1,])))

ans <- array(dim = c(1, dim(dprep)))
ans[1,,] <- dprep
abm_hist_feat <- ans
# Remove the first observation (often zero) and take differences
# ans is now a data.table with rows representing days
# ans[-1, ] removes the first row
dprep <- t(diff(as.matrix(err[-1,])))

array_reshape(
abm_hist_feat,
dim = c(1, dim(dprep))
)
# Construct a 3D array with shape (1, n_features, n_timesteps)
# Here n_features = number of variables (rows of dprep after transpose)
# and n_timesteps = number of columns (days-1)
ans_array <- array(dim = c(1, dim(dprep)[1], dim(dprep)[2]))
ans_array[1,,] <- dprep

# Reshape for TensorFlow input using keras3 (adjust if using another keras interface)
keras3::array_reshape(ans_array, dim = c(1, dim(dprep)))
}



#
# prepare_data_infections_only <- function(m, ...) {
# UseMethod("prepare_data_infectios_only")
# }
#
# prepare_data_infections_only.epiworld_model <- function(m, ...) {
# ans <- epiworldR::plot_incidence(m, plot = FALSE) |>
# data.table::as.data.table()
#
# prepare_data_infections_only.data.table(
# m = ans,
# ...
# )
# }
#
# prepare_data_infections_only.default <- function(m, max_days = 50, ...) {
#
# err <- tryCatch({
# ans <- list(
# incidence = data.table(Infected=m)
# )
#
# # Replacing NaN and NAs with the previous value
# # in each element in the list
# # Filtering up to max_days
# ans$incidence <- ans$incidence[as.integer(rownames(ans$incidence)) <= (max_days + 1),]
#
# # Reference table for merging
# # ndays <- epiworldR::get_ndays(m)
# ref_table <- data.table::data.table(
# date = 0:max_days
# )
#
# # Generating the arrays
# ans <- data.table::data.table(
# infected = ans[["incidence"]][["Infected"]]
# )
#
# # Filling NAs with last obs
# ans[, "infected" := data.table::nafill(.SD[[1]], "locf"),
# .SDcols = "infected"]
#
# }, error = function(e) e)
#
# # If there is an error, return NULL
# if (inherits(err, "error")) {
# return(err)
# }
#
# # Returning without the first observation (which is mostly zero)
# dprep <- t(diff(as.matrix(ans[-1,])))
#
# ans <- array(dim = c(1, dim(dprep)))
# ans[1,,] <- dprep
# abm_hist_feat <- ans
#
# array_reshape(
# abm_hist_feat,
# dim = c(1, dim(dprep))
# )
#
# }
2 changes: 1 addition & 1 deletion R/evaluate_model.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
#' }
#' @export
evaluate_model <- function(model, test_data, theta) {
pred <- predict(model, x = test_data$x) |>
pred <- model$predict(test_data$x) |>
data.table::as.data.table() |>
data.table::setnames(colnames(theta))

Expand Down
2 changes: 1 addition & 1 deletion R/filter_non_null.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
filter_non_null <- function(matrices, theta) {
is_not_null <- intersect(
which(!sapply(matrices, inherits, what = "error")),
which(!sapply(matrices, \(x) any(is.na(x))))
which(!sapply(matrices, function(x) any(is.na(x))))
)

matrices <- matrices[is_not_null]
Expand Down
3 changes: 1 addition & 2 deletions R/filter_non_null_infected.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,10 @@
#' The function then returns the filtered matrices and their count.
#'
#' @export

filter_non_null_infected <- function(matrices) {
is_not_null <- intersect(
which(!sapply(matrices, inherits, what = "error")),
which(!sapply(matrices, \(x) any(is.na(x))))
which(!sapply(matrices, function(x) any(is.na(x))))
)

matrices <- matrices[is_not_null]
Expand Down
9 changes: 4 additions & 5 deletions R/generate_theta.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,16 @@
#'
#' @param N Integer. The number of parameter sets to generate.
#' @param n Integer. The population size for each simulation.
#' @importFrom stats plogis predict qlogis rbeta rgamma
#' @importFrom stats plogis qlogis rbeta rgamma
#' @return A data.table containing the generated parameters.
#' @export
generate_theta <- function(N, n) {
library(data.table)
set.seed(1231)
theta <- data.table::data.table(
preval = sample((100:2000) / n, N, TRUE),
crate = rgamma(N, 5, 1), # Mean 5
ptran = rbeta(N, 3, 7), # Mean 0.3
prec = rbeta(N, 10, 10) # Mean 0.5
crate = stats::rgamma(N, 5, 1),
ptran = stats::rbeta(N, 3, 7),
prec = stats::rbeta(N, 10, 10)
)
return(theta)
}
Loading

0 comments on commit e99b9d9

Please sign in to comment.