Skip to content

Commit

Permalink
new change in calling directory in calibrate_sir
Browse files Browse the repository at this point in the history
  • Loading branch information
sima-njf committed Dec 10, 2024
1 parent 42ee35a commit 2ed39c8
Showing 1 changed file with 50 additions and 50 deletions.
100 changes: 50 additions & 50 deletions R/calibrate_sir.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,62 +21,62 @@
#' @details
#' The function determines which pre-trained CNN model to load based on the number of features (columns) in the input `data`. If `data` has 30 columns, it loads the `sir30-cnn.keras` model; if it has 60 columns, it loads the `sir60-cnn.keras` model. Ensure that the input data matches one of these expected formats to avoid errors.
#' @export
# calibrate_sir <- function(data) {
# library(keras3)
# ans=preprocessing_data(data)
# a=length(ans)
# ans <- tensorflow::array_reshape(ans, dim = c(1, 1, a, 1))
#
# if(a <=30){
# model <- keras3::load_model(
# system.file("models", "sir30-cnn.keras", package = "epiworldRcalibrate")
# )
# }
# else{
# model <- keras3::load_model(
# system.file("models", "sir60-cnn.keras", package = "epiworldRcalibrate")
# )
# }
# pred <- predict(model, x =ans ) |>
# data.table::as.data.table() |>
# data.table::setnames(c("preval","crate","ptran","prec"))
# pred$crate=qlogis(pred$crate)
#
# return(list(pred = pred))
# }
calibrate_sir <- function(data) {
# Load required libraries
library(tensorflow)
library(data.table)

# Preprocess the data
ans <- preprocessing_data(data)
a <- length(ans)
ans <- tensorflow::array_reshape(ans, dim = c(1, 1, a, 1)) # Reshape for the model
library(keras3)
ans=preprocessing_data(data)
a=length(ans)
ans <- tensorflow::array_reshape(ans, dim = c(1, 1, a, 1))

# Determine model file path
model_path <- if (a <= 31) {
system.file("models", "sir30-cnn.keras", package = "epiworldRcalibrate")
} else {
system.file("models", "sir60-cnn.keras", package = "epiworldRcalibrate")
if(a <=30){
model <- keras3::load_model(
system.file("models", "sir30-cnn.keras", package = "epiworldRcalibrate")
)
}

# Check if the model file exists
if (model_path == "") {
stop("Model file not found. Please ensure the models are included in the 'epiworldRcalibrate' package.")
else{
model <- keras3::load_model(
system.file("models", "sir60-cnn.keras", package = "epiworldRcalibrate")
)
}

# Load the model using tensorflow
model <- tensorflow::tf$keras$models$load_model(model_path)

# Make predictions
pred <- model$predict(ans) |>
pred <- predict(model, x =ans ) |>
data.table::as.data.table() |>
data.table::setnames(c("preval", "crate", "ptran", "prec"))

data.table::setnames(c("preval","crate","ptran","prec"))
pred$crate=qlogis(pred$crate)

# Return predictions as a list
return(list(pred = pred))
}

# calibrate_sir <- function(data) {
# # Load required libraries
# library(tensorflow)
# library(data.table)
#
# # Preprocess the data
# ans <- preprocessing_data(data)
# a <- length(ans)
# ans <- tensorflow::array_reshape(ans, dim = c(1, 1, a, 1)) # Reshape for the model
#
# # Determine model file path
# model_path <- if (a <= 31) {
# system.file("models", "sir30-cnn.keras", package = "epiworldRcalibrate")
# } else {
# system.file("models", "sir60-cnn.keras", package = "epiworldRcalibrate")
# }
#
# # Check if the model file exists
# if (model_path == "") {
# stop("Model file not found. Please ensure the models are included in the 'epiworldRcalibrate' package.")
# }
#
# # Load the model using tensorflow
# model <- tensorflow::tf$keras$models$load_model(model_path)
#
# # Make predictions
# pred <- model$predict(ans) |>
# data.table::as.data.table() |>
# data.table::setnames(c("preval", "crate", "ptran", "prec"))
#
#
# # Return predictions as a list
# return(list(pred = pred))
# }
#

0 comments on commit 2ed39c8

Please sign in to comment.