Skip to content

Commit

Permalink
Add set_rand_engine function and update constructor to extract rand_e…
Browse files Browse the repository at this point in the history
…ngine from model
  • Loading branch information
apulsipher committed Oct 23, 2024
1 parent 7cf7281 commit f48be56
Show file tree
Hide file tree
Showing 7 changed files with 71 additions and 18 deletions.
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ S3method(set_observed_data,epiworld_lfmcmc)
S3method(set_par_names,epiworld_lfmcmc)
S3method(set_param,epiworld_model)
S3method(set_proposal_fun,epiworld_lfmcmc)
S3method(set_rand_engine_lfmcmc,epiworld_lfmcmc)
S3method(set_simulation_fun,epiworld_lfmcmc)
S3method(set_stats_names,epiworld_lfmcmc)
S3method(set_summary_fun,epiworld_lfmcmc)
Expand Down Expand Up @@ -209,6 +210,7 @@ export(set_prob_recovery)
export(set_prob_recovery_fun)
export(set_prob_recovery_ptr)
export(set_proposal_fun)
export(set_rand_engine_lfmcmc)
export(set_recovery_enhancer)
export(set_recovery_enhancer_fun)
export(set_recovery_enhancer_ptr)
Expand Down
21 changes: 19 additions & 2 deletions R/LFMCMC.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,18 @@
#' @aliases epiworld_lfmcmc
#' @details
#' TODO: Detail LFMCMC
#' TODO: Add params
#' @returns
#' - The `LFMCMC`function returns a model of class [epiworld_lfmcmc].
#' @examples
#' model_lfmcmc <- LFMCMC()
#' @export
LFMCMC <- function() {
LFMCMC <- function(model) {
if (!inherits(model, "epiworld_model"))
stop("model should be of class 'epiworld_model'. It is of class ", class(model))

structure(
LFMCMC_cpp(),
LFMCMC_cpp(model),
class = c("epiworld_lfmcmc")
)
}
Expand Down Expand Up @@ -96,6 +100,19 @@ set_kernel_fun.epiworld_lfmcmc <- function(lfmcmc, fun) {
invisible(lfmcmc)
}

#' @rdname LFMCMC
#' @param lfmcmc LFMCMC model
#' @param eng The rand engine
#' @returns The lfmcmc model with the engine set
#' @export
set_rand_engine_lfmcmc <- function(lfmcmc, eng) UseMethod("set_rand_engine_lfmcmc")

#' @export
set_rand_engine_lfmcmc.epiworld_lfmcmc <- function(lfmcmc, eng) {
set_rand_engine_lfmcmc_cpp(lfmcmc, eng)
invisible(lfmcmc)
}

#' @rdname LFMCMC
#' @param lfmcmc LFMCMC model
#' @param s The rand engine seed
Expand Down
8 changes: 6 additions & 2 deletions R/cpp11.R
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,8 @@ ModelSEIRMixing_cpp <- function(name, n, prevalence, contact_rate, transmission_
.Call(`_epiworldR_ModelSEIRMixing_cpp`, name, n, prevalence, contact_rate, transmission_rate, incubation_days, recovery_rate, contact_matrix)
}

LFMCMC_cpp <- function() {
.Call(`_epiworldR_LFMCMC_cpp`)
LFMCMC_cpp <- function(m) {
.Call(`_epiworldR_LFMCMC_cpp`, m)
}

run_lfmcmc_cpp <- function(lfmcmc, params_init_, n_samples_, epsilon_) {
Expand Down Expand Up @@ -264,6 +264,10 @@ set_kernel_fun_cpp <- function(lfmcmc, fun) {
.Call(`_epiworldR_set_kernel_fun_cpp`, lfmcmc, fun)
}

set_rand_engine_lfmcmc_cpp <- function(lfmcmc, eng) {
.Call(`_epiworldR_set_rand_engine_lfmcmc_cpp`, lfmcmc, eng)
}

seed_lfmcmc_cpp <- function(lfmcmc, s) {
.Call(`_epiworldR_seed_lfmcmc_cpp`, lfmcmc, s)
}
Expand Down
10 changes: 9 additions & 1 deletion man/LFMCMC.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 12 additions & 4 deletions src/cpp11.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -391,10 +391,10 @@ extern "C" SEXP _epiworldR_ModelSEIRMixing_cpp(SEXP name, SEXP n, SEXP prevalenc
END_CPP11
}
// lfmcmc.cpp
SEXP LFMCMC_cpp();
extern "C" SEXP _epiworldR_LFMCMC_cpp() {
SEXP LFMCMC_cpp(SEXP m);
extern "C" SEXP _epiworldR_LFMCMC_cpp(SEXP m) {
BEGIN_CPP11
return cpp11::as_sexp(LFMCMC_cpp());
return cpp11::as_sexp(LFMCMC_cpp(cpp11::as_cpp<cpp11::decay_t<SEXP>>(m)));
END_CPP11
}
// lfmcmc.cpp
Expand Down Expand Up @@ -468,6 +468,13 @@ extern "C" SEXP _epiworldR_set_kernel_fun_cpp(SEXP lfmcmc, SEXP fun) {
END_CPP11
}
// lfmcmc.cpp
SEXP set_rand_engine_lfmcmc_cpp(SEXP lfmcmc, SEXP eng);
extern "C" SEXP _epiworldR_set_rand_engine_lfmcmc_cpp(SEXP lfmcmc, SEXP eng) {
BEGIN_CPP11
return cpp11::as_sexp(set_rand_engine_lfmcmc_cpp(cpp11::as_cpp<cpp11::decay_t<SEXP>>(lfmcmc), cpp11::as_cpp<cpp11::decay_t<SEXP>>(eng)));
END_CPP11
}
// lfmcmc.cpp
SEXP seed_lfmcmc_cpp(SEXP lfmcmc, unsigned long long int s);
extern "C" SEXP _epiworldR_seed_lfmcmc_cpp(SEXP lfmcmc, SEXP s) {
BEGIN_CPP11
Expand Down Expand Up @@ -1016,7 +1023,7 @@ extern "C" SEXP _epiworldR_distribute_virus_to_set_cpp(SEXP agents_ids) {

extern "C" {
static const R_CallMethodDef CallEntries[] = {
{"_epiworldR_LFMCMC_cpp", (DL_FUNC) &_epiworldR_LFMCMC_cpp, 0},
{"_epiworldR_LFMCMC_cpp", (DL_FUNC) &_epiworldR_LFMCMC_cpp, 1},
{"_epiworldR_ModelDiffNet_cpp", (DL_FUNC) &_epiworldR_ModelDiffNet_cpp, 8},
{"_epiworldR_ModelSEIRCONN_cpp", (DL_FUNC) &_epiworldR_ModelSEIRCONN_cpp, 7},
{"_epiworldR_ModelSEIRDCONN_cpp", (DL_FUNC) &_epiworldR_ModelSEIRDCONN_cpp, 8},
Expand Down Expand Up @@ -1140,6 +1147,7 @@ static const R_CallMethodDef CallEntries[] = {
{"_epiworldR_set_prob_recovery_fun_cpp", (DL_FUNC) &_epiworldR_set_prob_recovery_fun_cpp, 3},
{"_epiworldR_set_prob_recovery_ptr_cpp", (DL_FUNC) &_epiworldR_set_prob_recovery_ptr_cpp, 3},
{"_epiworldR_set_proposal_fun_cpp", (DL_FUNC) &_epiworldR_set_proposal_fun_cpp, 2},
{"_epiworldR_set_rand_engine_lfmcmc_cpp", (DL_FUNC) &_epiworldR_set_rand_engine_lfmcmc_cpp, 2},
{"_epiworldR_set_recovery_enhancer_cpp", (DL_FUNC) &_epiworldR_set_recovery_enhancer_cpp, 2},
{"_epiworldR_set_recovery_enhancer_fun_cpp", (DL_FUNC) &_epiworldR_set_recovery_enhancer_fun_cpp, 3},
{"_epiworldR_set_recovery_enhancer_ptr_cpp", (DL_FUNC) &_epiworldR_set_recovery_enhancer_ptr_cpp, 3},
Expand Down
19 changes: 17 additions & 2 deletions src/lfmcmc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,15 @@ using namespace epiworld;
// https://github.com/UofUEpiBio/epiworld/tree/master/include/epiworld/math/lfmcmc

[[cpp11::register]]
SEXP LFMCMC_cpp() {
SEXP LFMCMC_cpp(
SEXP m
) {
WrapLFMCMC(lfmcmc_ptr)(
new LFMCMC<TData_default>()
);

lfmcmc_ptr->set_rand_engine(cpp11::external_pointer<Model<>>(m)->get_rand_endgine());

return lfmcmc_ptr;
}

Expand Down Expand Up @@ -131,7 +135,6 @@ SEXP set_summary_fun_cpp(
}

// LFMCMC Kernel Function
// TODO: clean up these really long lines
[[cpp11::register]]
SEXP create_LFMCMCKernelFun_cpp(
cpp11::function fun
Expand Down Expand Up @@ -159,6 +162,18 @@ SEXP set_kernel_fun_cpp(
return lfmcmc;
}

// Rand Engine
[[cpp11::register]]
SEXP set_rand_engine_lfmcmc_cpp(
SEXP lfmcmc,
SEXP eng
) {
cpp11::external_pointer<std::mt19937> eng_ptr(eng);
WrapLFMCMC(lfmcmc_ptr)(lfmcmc);
lfmcmc_ptr->set_rand_engine(*eng_ptr);
return lfmcmc;
}

// s should be of type epiworld_fast_uint
[[cpp11::register]]
SEXP seed_lfmcmc_cpp(
Expand Down
13 changes: 6 additions & 7 deletions vignettes/likelihood-free-mcmc.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ print(model_sir)

## Extract Observed data
```{r extract-obs-data}
obs_data <- get_today_total(model_sir)
obs_data <- as.integer(get_today_total(model_sir))
```

## Setup LFMCMC
Expand Down Expand Up @@ -100,18 +100,17 @@ par0 <- c(.5, .5)
## Run LFMCMC
```{r lfmcmc-run}
# TODO: make these work
lfmcmc_model <- LFMCMC() |>
lfmcmc_model <- LFMCMC(model_sir) |>
set_simulation_fun(simfun) |>
set_summary_fun(sumfun) |>
set_proposal_fun(propfun) |>
set_kernel_fun(kernfun)
# set_observed_data(obs_dat) |>
# run_lfmcmc(par0, 2000, 1)
set_kernel_fun(kernfun) |>
set_observed_data(obs_data)
# run_lfmcmc(par0, 2000, 1)
# lfmcmc_model
# lfmcmc_model <- seed(lfmcmc_model, model_seed) |>
# set_par_names(c("Immune recovery", "Infectiousness")) |>
# lfmcmc_model <- set_par_names(c("Immune recovery", "Infectiousness")) |>
# set_stats_names(get_states(model_sir)) |>
# print()
```

0 comments on commit f48be56

Please sign in to comment.