Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DO NOT MERGE: brms dev ideas #383

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ License: MIT + file LICENSE
Depends:
R (>= 3.5)
Imports:
brms,
checkmate,
cli,
data.table,
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
S3method(as_summary_list,bbi_summary_log_df)
S3method(build_data,bbi_stan_model)
S3method(build_path_from_model,bbi_model)
S3method(build_path_from_model,character)
S3method(check_ext,bbi_nonmem_model)
S3method(check_ext,bbi_nonmem_summary)
S3method(check_ext,character)
Expand Down
8 changes: 8 additions & 0 deletions R/aaa-stan.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ STAN_MODEL_FILES_TO_CHECK <- c(
STANINIT_SUFFIX
)

STAN_FILES_TO_PRINT <- c(
STANINIT_SUFFIX,
STANARGS_SUFFIX,
STAN_MODEL_FIT_RDS
)

STAN_RESERVED_ARGS <- c(
"data",
"init",
Expand All @@ -28,6 +34,8 @@ STANCFG_ARGS_MD5 <- "stanargs_md5"

STAN_BBI_VERSION_STRING <- "STAN"

STANDATA_BRMS_COMMENT <- "# standata created by brms"

############
# SCAFFOLDS
############
Expand Down
35 changes: 25 additions & 10 deletions R/build-data.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
#' @param .mod a `bbi_{.model_type}_model` object.
#' @param .out_path If `NULL`, the default, does not write any data to disk.
#' Otherwise, pass a file path where the resulting data object should be
#' written.
#' written. If a file already exists at this path, it will be overwritten.
#' @param ... Arguments passed through to methods (currently none).
#'
#' @importFrom checkmate assert_string
Expand All @@ -39,18 +39,31 @@ build_data <- function(.mod, .out_path = NULL, ...) {
#'
#' @export
build_data.bbi_stan_model <- function(.mod, .out_path = NULL, ...) {
# source and call function
standata_r_path <- build_path_from_model(.mod, STANDATA_R_SUFFIX)
make_standata <- safe_source_function(standata_r_path, "make_standata")
standata_list <- safe_call_sourced(
.func = make_standata,
.args = list(.dir = dirname(get_output_dir(.mod, .check_exists = FALSE))),
.file = standata_r_path,
.expected_class = "list"
)
standata_json_path <- build_path_from_model(.mod, STANDATA_JSON_SUFFIX)

# check if it's a placeholder from brms
brmsbool <- FALSE
if (file_matches_string(standata_r_path, STANDATA_BRMS_COMMENT)) {
brmsbool <- TRUE
message("standata constructed with brms. Loading directly from json.")
if (!fs::file_exists(standata_json_path)) {
stop(glue("{basename(standata_r_path)} was constructed by brms, but no data file exists at {standata_json_path}"), call. = F)
}
standata_list <- jsonlite::fromJSON(standata_json_path)
} else {
# source and call function
make_standata <- safe_source_function(standata_r_path, "make_standata")
standata_list <- safe_call_sourced(
.func = make_standata,
.args = list(.dir = dirname(get_output_dir(.mod, .check_exists = FALSE))),
.file = standata_r_path,
.expected_class = "list"
)
}

# optionally write to json
if (!is.null(.out_path)) {
if (!is.null(.out_path) && !all(brmsbool, .out_path == standata_json_path)) {
if (!requireNamespace("cmdstanr", quietly = TRUE)) {
stop("Must have cmdstanr installed to use build_data.bbi_stan_model()")
}
Expand All @@ -59,6 +72,8 @@ build_data.bbi_stan_model <- function(.mod, .out_path = NULL, ...) {
if (!str_detect(.out_path, ".json$")) {
stop(glue("build_data.bbi_stan_model(.out_path) must end in '.json' because a JSON file will be written. Got {.out_path}"), .call = FALSE)
}

if (fs::file_exists(.out_path)) fs::file_delete(.out_path)
cmdstanr::write_stan_json(standata_list, .out_path)
}

Expand Down
10 changes: 10 additions & 0 deletions R/copy-model-from.R
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,16 @@ copy_stan_files <- function(.parent_mod, .new_model, .overwrite) {
)
}
})

# if -standata.R is a brms stub, copy through -standata.json
parent_standata_r_path <- build_path_from_model(.parent_mod, STANDATA_R_SUFFIX)
parent_standata_json_path <- build_path_from_model(.parent_mod, STANDATA_JSON_SUFFIX)
if (file_matches_string(parent_standata_r_path, STANDATA_BRMS_COMMENT)) {
fs::file_copy(
parent_standata_json_path,
build_path_from_new_model_path(.new_model, STANDATA_JSON_SUFFIX)
)
}
}

#' Private helper to build absolute path for [copy_model_from()].
Expand Down
16 changes: 15 additions & 1 deletion R/get-path-from-object.R
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ build_path_from_model <- function(.mod, .suffix, ...) {
UseMethod("build_path_from_model")
}

#' @rdname build_path_from_model
#' @describeIn build_path_from_model Takes any `bbi_model` object
#' @export
build_path_from_model.bbi_model <- function(.mod, .suffix, ...) {
file.path(
Expand All @@ -252,6 +252,20 @@ build_path_from_model.bbi_model <- function(.mod, .suffix, ...) {
)
}

#' @describeIn build_path_from_model Takes an absolute model path (without file extension)
#' @importFrom checkmate assert_string
#' @importFrom fs is_absolute_path
#' @export
build_path_from_model.character <- function(.mod, .suffix, ...) {
checkmate::assert_string(.mod)
if (!fs::is_absolute_path(.mod)) stop(paste("Can only pass a `bbi_model` object or an absolute path to `build_path_from_model(). Passed", .mod))

file.path(
.mod,
paste0(get_model_id(.mod), .suffix)
)
}


###################################################################
# Model-specific get path from bbi object implementation functions
Expand Down
97 changes: 83 additions & 14 deletions R/new-model.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ new_model <- function(
.bbi_args = NULL,
.overwrite = FALSE,
.star = NULL,
.model_type = c("nonmem", "stan")
.model_type = c("nonmem", "stan"),
...
) {

.model_type <- match.arg(.model_type)
Expand All @@ -63,18 +64,30 @@ new_model <- function(
basename(.path)
)

# create model object
.mod <- list()
.mod[[ABS_MOD_PATH]] <- abs_mod_path
.mod[[YAML_MOD_TYPE]] <- .model_type
.mod <- create_model_object(.mod, save_yaml = TRUE)

# update model from passed args
if (!is.null(.description)) .mod <- replace_description(.mod, .description)
if (!is.null(.tags)) .mod <- replace_all_tags(.mod, .tags)
if (!is.null(.bbi_args)) .mod <- replace_all_bbi_args(.mod, .bbi_args)
if (!is.null(.based_on)) .mod <- replace_all_based_on(.mod, .based_on)
if (isTRUE(.star)) .mod <- add_star(.mod)
tryCatch({
parse_new_model_dots(.model_type, abs_mod_path, ...)

# create model object
.mod <- list()
.mod[[ABS_MOD_PATH]] <- abs_mod_path
.mod[[YAML_MOD_TYPE]] <- .model_type
.mod <- create_model_object(.mod, save_yaml = TRUE)

# update model from passed args
if (!is.null(.description)) .mod <- replace_description(.mod, .description)
if (!is.null(.tags)) .mod <- replace_all_tags(.mod, .tags)
if (!is.null(.bbi_args)) .mod <- replace_all_bbi_args(.mod, .bbi_args)
if (!is.null(.based_on)) .mod <- replace_all_based_on(.mod, .based_on)
if (isTRUE(.star)) .mod <- add_star(.mod)
},
error = function(.e) {
if (fs::dir_exists(abs_mod_path)) fs::dir_delete(abs_mod_path)
files_to_kill <- c(yaml_ext(abs_mod_path), ctl_ext(abs_mod_path), mod_ext(abs_mod_path))
purrr::walk(files_to_kill, ~ {if (fs::file_exists(.x)) fs::file_delete(.x)})
stop(paste("new_model() failed for", abs_mod_path, "with the following error:\n", paste(.e, collapse = "\n")))
}
)

return(.mod)
}

Expand Down Expand Up @@ -167,6 +180,63 @@ check_for_existing_model <- function(.path, .overwrite) {
}
}

#' Parse ... passed to new_model()
#'
#' @importFrom ellipsis check_dots_empty
#' @keywords internal
parse_new_model_dots <- function(.model_type, .path, ...) {
args <- list(...)

if (.model_type == "nonmem") {
if (length(args) > 0) {
stop(paste(
"You have passed extra arguments to `new_model()` via `...`, which is NOT valid for NONMEM models.",
NONMEM_MODEL_TYPE_ERR_MSG,
glue("The extra passed arguments are {paste(names(args), collapse = ', ')}"),
sep = "\n"), call. = FALSE)
}
} else if (.model_type == "stan") {
if (length(args) > 0) {
if (all(c("formula", "data") %in% names(args))) {
stan_files_from_brms(.path, args)
} else {
stop(paste(
"You have passed extra arguments to `new_model(.model_type = 'stan')` via `...`",
"This is used for constructing a `bbi_stan_model` with `brms` and REQUIRES both `formula` and `data` to be passed.",
glue("The extra passed arguments are {paste(names(args), collapse = ', ')}"),
sep = "\n"), call. = FALSE)
}
}
}
}


#' Write necessary stan files to disk from brms args
#'
#' @importFrom readr write_lines
#'
#' @param .path absolute model path (files will be created in this dir)
#' @param args named list of arguments to pass to [brms::make_stancode] and [brms::make_standata]
#' @keywords internal
stan_files_from_brms <- function(.path, args) {
if (!requireNamespace("cmdstanr", quietly = TRUE)) {
stop("Must have cmdstanr installed to use create a `bbi_stan_model` with `brms`")
}
if (!requireNamespace("brms", quietly = TRUE)) {
stop("Must have brms installed to use create a `bbi_stan_model` with `brms`")
}
if (!fs::dir_exists(.path)) fs::dir_create(.path)

# write .stan file
stan_code <- do.call(brms::make_stancode, args)
write_lines(stan_code, build_path_from_model(.path, STANMOD_SUFFIX))

# write data file
stan_data <- do.call(brms::make_standata, args)
cmdstanr::write_stan_json(unclass(stan_data), build_path_from_model(.path, STANDATA_JSON_SUFFIX))
write_lines(STANDATA_BRMS_COMMENT, build_path_from_model(.path, STANDATA_R_SUFFIX))
}

#' Private helper to remove file extensions to match expected input to new model.
#' @inheritParams new_model
#' @keywords internal
Expand All @@ -178,4 +248,3 @@ sanitize_file_extension <- function(.path)
}
return(.path)
}

5 changes: 4 additions & 1 deletion R/print.R
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,10 @@ print.bbi_model <- function(x, ...) {
bullet_list(get_model_path(x, .check_exists = FALSE))
if (inherits(x, STAN_MOD_CLASS)) {
bullet_list(build_path_from_model(x, STANDATA_R_SUFFIX))
bullet_list(build_path_from_model(x, STANINIT_SUFFIX))
for (.s in STAN_FILES_TO_PRINT) {
.f <- build_path_from_model(x, .s)
if (fs::file_exists(.f)) bullet_list(.f)
}
check_stan_model(x)
}

Expand Down
5 changes: 3 additions & 2 deletions R/submit-model-stan-helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,9 @@ parse_stanargs <- function(.mod, valid_stanargs, ...) {
if (any(names(stanargs) %in% STAN_RESERVED_ARGS)) {
stop(paste(
"Cannot pass any of the following through submit_model() to cmdstanr",
glue("because they are parsed internally from the model object: {paste(STAN_RESERVED_ARGS, collapse = ', ')}")
))
glue("because they are parsed internally from the model object: {paste(STAN_RESERVED_ARGS, collapse = ', ')} --\n"),
"Use add_standata_file() or add_stan_init() instead. See ?bbi_stan_model for more details."
), call. = FALSE)
}

invalid_stanargs <- setdiff(names(stanargs), valid_stanargs)
Expand Down
2 changes: 0 additions & 2 deletions R/submit-model.R
Original file line number Diff line number Diff line change
Expand Up @@ -228,8 +228,6 @@ submit_stan_model_cmdstanr <- function(.mod,
if (isTRUE(.overwrite)) {
fs::dir_delete(out_dir)
fs::dir_create(out_dir)
if(fs::file_exists(standata_json_path)) { fs::file_delete(standata_json_path) }
if(fs::file_exists(stanargs_path)) { fs::file_delete(stanargs_path) }
} else {
stop(glue("{out_dir} already exists. Pass submit_model(..., .overwrite = TRUE) to delete it and re-run the model."), call. = FALSE)
}
Expand Down
Loading