Skip to content

Commit

Permalink
no warning no error
Browse files Browse the repository at this point in the history
  • Loading branch information
sima-njf committed Dec 11, 2024
1 parent e99b9d9 commit 7d6596f
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 81 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ export(run_simulations)
export(simulate_calibrate_sir)
export(split_data)
export(train_model)
import(epiworldR)
importFrom(data.table,as.data.table)
importFrom(data.table,copy)
importFrom(data.table,dcast)
Expand Down
79 changes: 44 additions & 35 deletions R/dataprep.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
#' @param max_days The maximum number of days to consider for data preparation. Defaults to 50.
#' @return A reshaped array of processed data suitable for TensorFlow models.
#' If an error occurs, it returns the error object.
#'@import epiworldR
#'
#' @export
prepare_data <- function(m, max_days = max_days) {

prepare_data <- function(m, max_days =max_days) {

err <- tryCatch({
ans <- list(
Expand All @@ -15,68 +17,75 @@ prepare_data <- function(m, max_days = max_days) {
gentime = epiworldR::plot_generation_time(m, plot = FALSE)
)

# Convert all elements to data.table
# Filling
ans <- lapply(ans, data.table::as.data.table)

# Replace NA values in repnum$avg and gentime$avg with the last observed value
ans[["repnum"]][, avg := data.table::nafill(avg, type = "locf")]
ans[["gentime"]][, avg := data.table::nafill(avg, type = "locf")]
# Replacing NaN and NAs with the previous value
# in each element in the list
# Replace NA values with the last observed value
ans$repnum$avg <- data.table::nafill(ans$repnum$avg, type = "locf")

# Replace NA values with the last observed value
ans$gentime$avg <- data.table::nafill(ans$gentime$avg, type = "locf")

# Filter up to max_days
ans[["repnum"]] <- ans[["repnum"]][date <= max_days, ]
ans[["gentime"]] <- ans[["gentime"]][date <= max_days, ]

# incidence is indexed by row number since date is not explicitly in that data
# We assume the first row represents day 0, hence (max_days + 1) rows total.
ans[["incidence"]] <- ans[["incidence"]][as.integer(.I) <= (max_days + 1), ]
# Filtering up to max_days
ans$repnum <- ans$repnum[ans$repnum$date <= max_days,]
ans$gentime <- ans$gentime[ans$gentime$date <= max_days,]
ans$incidence <- ans$incidence[as.integer(rownames(ans$incidence)) <= (max_days + 1),]

# Create a reference table for merging
ref_table <- data.table::data.table(date = 0:max_days)
# Reference table for merging
# ndays <- epiworldR::get_ndays(m)

ref_table <- data.table::data.table(
date = 0:max_days
)

# Merge repnum and gentime with the reference table to ensure consistent length
# Replace the $ with the [[ ]] to avoid the warning in the next
# two lines
ans[["repnum"]] <- data.table::merge.data.table(
ref_table, ans[["repnum"]], by = "date", all.x = TRUE
)
ans[["gentime"]] <- data.table::merge.data.table(
ref_table, ans[["gentime"]], by = "date", all.x = TRUE
)

# Create a data.table with all required columns

# Generating the data.table with necessary columns
ans <- data.table::data.table(
infected = ans[["incidence"]][["Infected"]],
recovered = ans[["incidence"]][["Recovered"]],
repnum = ans[["repnum"]][["avg"]],
gentime = ans[["gentime"]][["avg"]],
repnum_sd = ans[["repnum"]][["sd"]],
infected = ans[["incidence"]][["Infected"]],
recovered = ans[["incidence"]][["Recovered"]],
repnum = ans[["repnum"]][["avg"]],
gentime = ans[["gentime"]][["avg"]],
repnum_sd = ans[["repnum"]][["sd"]],
gentime_sd = ans[["gentime"]][["sd"]]
)

# Replace NA values in all relevant columns using locf
# Replace NA values with the last observed value for all columns
nafill_cols <- c("infected", "recovered", "repnum", "gentime", "repnum_sd", "gentime_sd")

for (col in nafill_cols) {
ans[[col]] <- data.table::nafill(ans[[col]], type = "locf")
}

# Return ans as processed data
ans
}, error = function(e) e)

# If there was an error, return it
# If there is an error, return NULL
if (inherits(err, "error")) {
return(err)
}

# Remove the first observation (often zero) and take differences
# ans is now a data.table with rows representing days
# ans[-1, ] removes the first row
dprep <- t(diff(as.matrix(err[-1,])))
# Returning without the first observation (which is mostly zero)
dprep <- t(diff(as.matrix(ans[-1,])))

# Construct a 3D array with shape (1, n_features, n_timesteps)
# Here n_features = number of variables (rows of dprep after transpose)
# and n_timesteps = number of columns (days-1)
ans_array <- array(dim = c(1, dim(dprep)[1], dim(dprep)[2]))
ans_array[1,,] <- dprep
ans <- array(dim = c(1, dim(dprep)))
ans[1,,] <- dprep
abm_hist_feat <- ans

tensorflow::array_reshape(
abm_hist_feat,
dim = c(1, dim(dprep))
)

# Reshape for TensorFlow input using keras3 (adjust if using another keras interface)
keras3::array_reshape(ans_array, dim = c(1, dim(dprep)))
}

60 changes: 15 additions & 45 deletions R/run_simulations.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,53 +15,23 @@
#' @return A list containing the simulation results as matrices.
#' @export
run_simulations <- function(N, n, ndays, ncores, theta, seeds) {
os_type <- .Platform$OS.type
matrices <- parallel::mclapply(1:N, FUN = function(i) {
set.seed(seeds[i])
m <- epiworldR:: ModelSIRCONN(
"mycon",
prevalence = theta$preval[i],
contact_rate = theta$crate[i],
transmission_rate = theta$ptran[i],
recovery_rate = theta$prec[i],
n = n
)

if (os_type == "windows") {
cl <- parallel::makeCluster(ncores)
on.exit(parallel::stopCluster(cl))
verbose_off(m)
run(m, ndays = ndays)
ans <- prepare_data(m,max_days=ndays)

parallel::clusterExport(cl, varlist = c("theta", "n", "ndays", "seeds", "prepare_data"), envir = environment())
parallel::clusterEvalQ(cl, {
# Load needed packages on workers if required (if not already loaded)
# library(epiworldR) # Not allowed, we rely on namespaces now
})

matrices <- parallel::parLapply(cl, 1:N, function(i) {
set.seed(seeds[i])
m <- epiworldR::ModelSIRCONN(
"mycon",
prevalence = theta$preval[i],
contact_rate = theta$crate[i],
transmission_rate = theta$ptran[i],
recovery_rate = theta$prec[i],
n = n
)

epiworldR::verbose_off(m)
epiworldR::run(m, ndays = ndays)
ans <- prepare_data(m, max_days = ndays)
return(ans)
})

} else {
matrices <- parallel::mclapply(1:N, function(i) {
set.seed(seeds[i])
m <- epiworldR::ModelSIRCONN(
"mycon",
prevalence = theta$preval[i],
contact_rate = theta$crate[i],
transmission_rate = theta$ptran[i],
recovery_rate = theta$prec[i],
n = n
)

epiworldR::verbose_off(m)
epiworldR::run(m, ndays = ndays)
ans <- prepare_data(m, max_days = ndays)
return(ans)
}, mc.cores = ncores)
}
return(ans)
}, mc.cores = ncores)

return(matrices)
}
2 changes: 1 addition & 1 deletion vignettes/calibrate.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ The function's arguments are:
The first step in the function is to generate the parameter sets (`theta`) and random seeds for the simulation:

```{r}
N=2e4
N=10
n=5000
theta <- generate_theta(N, n)
head(theta,5)
Expand Down

0 comments on commit 7d6596f

Please sign in to comment.