From 2ed39c8e26abb79347c19131478510ca1460715f Mon Sep 17 00:00:00 2001 From: Sima Najafzadehkhoei Date: Tue, 10 Dec 2024 15:12:38 -0700 Subject: [PATCH] new change in calling directory in calibrate_sir --- R/calibrate_sir.R | 100 +++++++++++++++++++++++----------------------- 1 file changed, 50 insertions(+), 50 deletions(-) diff --git a/R/calibrate_sir.R b/R/calibrate_sir.R index fc84081..cba7ce5 100644 --- a/R/calibrate_sir.R +++ b/R/calibrate_sir.R @@ -21,62 +21,62 @@ #' @details #' The function determines which pre-trained CNN model to load based on the number of features (columns) in the input `data`. If `data` has 30 columns, it loads the `sir30-cnn.keras` model; if it has 60 columns, it loads the `sir60-cnn.keras` model. Ensure that the input data matches one of these expected formats to avoid errors. #' @export -# calibrate_sir <- function(data) { -# library(keras3) -# ans=preprocessing_data(data) -# a=length(ans) -# ans <- tensorflow::array_reshape(ans, dim = c(1, 1, a, 1)) -# -# if(a <=30){ -# model <- keras3::load_model( -# system.file("models", "sir30-cnn.keras", package = "epiworldRcalibrate") -# ) -# } -# else{ -# model <- keras3::load_model( -# system.file("models", "sir60-cnn.keras", package = "epiworldRcalibrate") -# ) -# } -# pred <- predict(model, x =ans ) |> -# data.table::as.data.table() |> -# data.table::setnames(c("preval","crate","ptran","prec")) -# pred$crate=qlogis(pred$crate) -# -# return(list(pred = pred)) -# } calibrate_sir <- function(data) { - # Load required libraries - library(tensorflow) - library(data.table) - - # Preprocess the data - ans <- preprocessing_data(data) - a <- length(ans) - ans <- tensorflow::array_reshape(ans, dim = c(1, 1, a, 1)) # Reshape for the model + library(keras3) + ans=preprocessing_data(data) + a=length(ans) + ans <- tensorflow::array_reshape(ans, dim = c(1, 1, a, 1)) - # Determine model file path - model_path <- if (a <= 31) { - system.file("models", "sir30-cnn.keras", package = "epiworldRcalibrate") - } else { - system.file("models", "sir60-cnn.keras", package = "epiworldRcalibrate") + if(a <=30){ + model <- keras3::load_model( + system.file("models", "sir30-cnn.keras", package = "epiworldRcalibrate") + ) } - - # Check if the model file exists - if (model_path == "") { - stop("Model file not found. Please ensure the models are included in the 'epiworldRcalibrate' package.") + else{ + model <- keras3::load_model( + system.file("models", "sir60-cnn.keras", package = "epiworldRcalibrate") + ) } - - # Load the model using tensorflow - model <- tensorflow::tf$keras$models$load_model(model_path) - - # Make predictions - pred <- model$predict(ans) |> + pred <- predict(model, x =ans ) |> data.table::as.data.table() |> - data.table::setnames(c("preval", "crate", "ptran", "prec")) - + data.table::setnames(c("preval","crate","ptran","prec")) + pred$crate=qlogis(pred$crate) - # Return predictions as a list return(list(pred = pred)) } - +# calibrate_sir <- function(data) { +# # Load required libraries +# library(tensorflow) +# library(data.table) +# +# # Preprocess the data +# ans <- preprocessing_data(data) +# a <- length(ans) +# ans <- tensorflow::array_reshape(ans, dim = c(1, 1, a, 1)) # Reshape for the model +# +# # Determine model file path +# model_path <- if (a <= 31) { +# system.file("models", "sir30-cnn.keras", package = "epiworldRcalibrate") +# } else { +# system.file("models", "sir60-cnn.keras", package = "epiworldRcalibrate") +# } +# +# # Check if the model file exists +# if (model_path == "") { +# stop("Model file not found. Please ensure the models are included in the 'epiworldRcalibrate' package.") +# } +# +# # Load the model using tensorflow +# model <- tensorflow::tf$keras$models$load_model(model_path) +# +# # Make predictions +# pred <- model$predict(ans) |> +# data.table::as.data.table() |> +# data.table::setnames(c("preval", "crate", "ptran", "prec")) +# +# +# # Return predictions as a list +# return(list(pred = pred)) +# } +#