Skip to content

Commit

Permalink
Enhance R wrapper functions to support length data and add integratio…
Browse files Browse the repository at this point in the history
…n tests

- Extend R wrapper functions to handle three scenarios: age input only,
  length input only, and combined age and length inputs.
- Remove `tests/testthat/fixtures/simulate-integration-test-data.R`
  and prepare test fixtures using `data-raw/data1.R` to eliminate
  duplicated OM simulation code.
- Refactor helper functions to consistently source age comps, length
  comps, and the age-to-length conversion matrix from OM datasets
  (`om_input$` or `em_input$`).
- Add integration tests to validate the functionality across the
  three input scenarios.
- Update NLL tests in integration tests without wrapper
  functions, as direct comparison of `jnll` from reports with
  OM-derived `jnll` is not feasible.
- Update `fims-demo` to demonstrate model runs using both age
  and length composition data.
- TODO: update the length comp likelihoods section from the "nll test of
  fims" in both
  `tests/testthat/test-integration-fims-estimation-with-wrappers.R` and
  `tests/testthat/test-integration-fims-estimation-without-wrappers.R`.

Co-authored-by: @nathanvaughan-NOAA <[email protected]>
Co-authored-by: @JaneSullivan-NOAA <[email protected]>
  • Loading branch information
3 people committed Dec 23, 2024
1 parent ec6dc90 commit c5f03ee
Show file tree
Hide file tree
Showing 17 changed files with 946 additions and 829 deletions.
20 changes: 20 additions & 0 deletions R/fimsfit.R
Original file line number Diff line number Diff line change
Expand Up @@ -108,68 +108,82 @@ NULL
#' is the returned object from [create_default_parameters()].
#' @export
#' @rdname get_FIMSFit
#' @keywords fit_fims
methods::setGeneric(
  # S4 generic: hands the call to the method registered for the class of `x`.
  name = "get_input",
  def = function(x) standardGeneric("get_input")
)
#' @export
#' @rdname get_FIMSFit
#' @keywords fit_fims
methods::setMethod(
  # Accessor for FIMSFit objects: returns the `input` slot unchanged.
  f = "get_input",
  signature = "FIMSFit",
  definition = function(x) x@input
)

#' @return
#' [get_report()] returns the TMB report, where anything that is flagged as
#' reportable in the C++ code is returned.
#' @export
#' @rdname get_FIMSFit
#' @keywords fit_fims
methods::setGeneric(
  # S4 generic: hands the call to the method registered for the class of `x`.
  # Namespace-qualified to match get_input().
  name = "get_report",
  def = function(x) standardGeneric("get_report")
)
#' @export
#' @rdname get_FIMSFit
#' @keywords fit_fims
methods::setMethod(
  # Accessor for FIMSFit objects: returns the `report` slot unchanged.
  f = "get_report",
  signature = "FIMSFit",
  definition = function(x) x@report
)

#' @return
#' [get_obj()] returns the output from [TMB::MakeADFun()].
#' @export
#' @rdname get_FIMSFit
#' @keywords fit_fims
methods::setGeneric(
  # S4 generic: hands the call to the method registered for the class of `x`.
  # Namespace-qualified to match get_input().
  name = "get_obj",
  def = function(x) standardGeneric("get_obj")
)
#' @export
#' @rdname get_FIMSFit
#' @keywords fit_fims
methods::setMethod(
  # Accessor for FIMSFit objects: returns the `obj` slot unchanged.
  f = "get_obj",
  signature = "FIMSFit",
  definition = function(x) x@obj
)

#' @return
#' [get_opt()] returns the output from [nlminb()], which is the minimizer used
#' in [fit_fims()].
#' @export
#' @rdname get_FIMSFit
#' @keywords fit_fims
methods::setGeneric(
  # S4 generic: hands the call to the method registered for the class of `x`.
  # Namespace-qualified to match get_input().
  name = "get_opt",
  def = function(x) standardGeneric("get_opt")
)
#' @export
#' @rdname get_FIMSFit
#' @keywords fit_fims
methods::setMethod(
  # Accessor for FIMSFit objects: returns the `opt` slot unchanged.
  f = "get_opt",
  signature = "FIMSFit",
  definition = function(x) x@opt
)

#' @return
#' [get_max_gradient()] returns the maximum gradient found when optimizing the
#' model.
#' @export
#' @rdname get_FIMSFit
#' @keywords fit_fims
methods::setGeneric(
  # S4 generic: hands the call to the method registered for the class of `x`.
  # Namespace-qualified to match get_input().
  name = "get_max_gradient",
  def = function(x) standardGeneric("get_max_gradient")
)
#' @export
#' @rdname get_FIMSFit
#' @keywords fit_fims
methods::setMethod(
  # Accessor for FIMSFit objects: returns the `max_gradient` slot unchanged.
  f = "get_max_gradient",
  signature = "FIMSFit",
  definition = function(x) x@max_gradient
)


#' @return
#' [get_sdreport()] returns the list from [TMB::sdreport()].
#' @export
#' @rdname get_FIMSFit
#' @keywords fit_fims
methods::setGeneric(
  # S4 generic: hands the call to the method registered for the class of `x`.
  # Namespace-qualified to match get_input().
  name = "get_sdreport",
  def = function(x) standardGeneric("get_sdreport")
)
#' @export
#' @rdname get_FIMSFit
#' @keywords fit_fims
methods::setMethod(
  # Accessor for FIMSFit objects: returns the `sdreport` slot unchanged.
  f = "get_sdreport",
  signature = "FIMSFit",
  definition = function(x) x@sdreport
)

#' @return
#' [get_estimates()] returns a tibble of parameter values and their
#' uncertainties from a fitted model.
#' @export
#' @rdname get_FIMSFit
#' @keywords fit_fims
methods::setGeneric(
  # S4 generic: hands the call to the method registered for the class of `x`.
  # Namespace-qualified to match get_input().
  name = "get_estimates",
  def = function(x) standardGeneric("get_estimates")
)
#' @export
#' @rdname get_FIMSFit
#' @keywords fit_fims
methods::setMethod(
  # Accessor for FIMSFit objects: returns the `estimates` slot unchanged.
  f = "get_estimates",
  signature = "FIMSFit",
  definition = function(x) x@estimates
)

#' @return
Expand All @@ -178,12 +192,14 @@ setMethod("get_estimates", "FIMSFit", function(x) x@estimates)
#' in the model.
#' @export
#' @rdname get_FIMSFit
#' @keywords fit_fims
methods::setGeneric(
  # S4 generic: hands the call to the method registered for the class of `x`.
  # Namespace-qualified to match get_input().
  name = "get_number_of_parameters",
  def = function(x) standardGeneric("get_number_of_parameters")
)
#' @export
#' @rdname get_FIMSFit
#' @keywords fit_fims
setMethod(
"get_number_of_parameters",
"FIMSFit",
Expand All @@ -195,19 +211,23 @@ setMethod(
#' seconds as a `difftime` object.
#' @export
#' @rdname get_FIMSFit
#' @keywords fit_fims
methods::setGeneric(
  # S4 generic: hands the call to the method registered for the class of `x`.
  # Namespace-qualified to match get_input().
  name = "get_timing",
  def = function(x) standardGeneric("get_timing")
)
#' @export
#' @rdname get_FIMSFit
#' @keywords fit_fims
methods::setMethod(
  # Accessor for FIMSFit objects: returns the `timing` slot unchanged.
  f = "get_timing",
  signature = "FIMSFit",
  definition = function(x) x@timing
)

#' @return
#' [get_version()] returns the `package_version` of FIMS that was used to fit
#' the model.
#' @export
#' @rdname get_FIMSFit
#' @keywords fit_fims
methods::setGeneric(
  # S4 generic: hands the call to the method registered for the class of `x`.
  # Namespace-qualified to match get_input().
  name = "get_version",
  def = function(x) standardGeneric("get_version")
)
#' @export
#' @rdname get_FIMSFit
#' @keywords fit_fims
methods::setMethod(
  # Accessor for FIMSFit objects: returns the `version` slot unchanged.
  f = "get_version",
  signature = "FIMSFit",
  definition = function(x) x@version
)

# methods::setValidity ----
Expand Down
5 changes: 3 additions & 2 deletions R/fimsframe.R
Original file line number Diff line number Diff line change
Expand Up @@ -627,12 +627,13 @@ FIMSFrame <- function(data) {
# Get the range of ages displayed in the data to use to specify population
# simulation range
# Use the distinct, non-missing ages that actually occur in the data. The
# former `min(...):max(...)` assignment was overwritten immediately and could
# fail when every age is NA, so it is not computed any more.
if ("age" %in% colnames(data)) {
  ages <- sort(unique(data[["age"]]))
  # sort() already drops NA by default; the filter keeps that intent explicit.
  ages <- ages[!is.na(ages)]
} else {
  # No age column: an empty vector makes the downstream length() equal 0.
  ages <- numeric()
}
n_ages <- length(ages)

if ("length" %in% colnames(data)) {
lengths <- sort(unique(data[["length"]]))
lengths <- lengths[!is.na(lengths)]
Expand Down
71 changes: 51 additions & 20 deletions R/initialize_modules.R
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,17 @@ initialize_module <- function(parameters, data, module_name) {

# Set the estimation information for the entire parameter vector
module[["age_length_conversion_matrix"]]$set_all_estimable(FALSE)

module[["age_length_conversion_matrix"]]$set_all_random(FALSE)
} else {
module_fields <- setdiff(module_fields, c(
# Right now we can also remove nlengths because the default is 0
"nlengths"
))
}

module_fields <- setdiff(module_fields, c(
"age_length_conversion_matrix",
# Right now we can also remove nlengths because the default is 0
"nlengths",
"proportion_catch_numbers_at_length"
))
}
Expand Down Expand Up @@ -403,12 +408,6 @@ initialize_selectivity <- function(parameters, data, fleet_name) {
#' The initialized fleet module as an object.
#' @noRd
initialize_fleet <- function(parameters, data, fleet_name, linked_ids) {
if (any(is.na(linked_ids[c("selectivity", "index", "age_comp")]))) {
cli::cli_abort(c(
"{.var linked_ids} for {fleet_name} must include 'selectivity', 'index',
and 'age_comp' IDs."
))
}

module <- initialize_module(
parameters = parameters,
Expand All @@ -418,12 +417,23 @@ initialize_fleet <- function(parameters, data, fleet_name, linked_ids) {

module$SetSelectivity(linked_ids["selectivity"])
module$SetObservedIndexData(linked_ids["index"])
module$SetObservedAgeCompData(linked_ids["age_comp"])

fleet_types <- get_data(data) |>
dplyr::filter(name == fleet_name) |>
dplyr::pull(type) |>
unique()

# Link the observed age composition data to the fleet module using its
# associated ID if the data type includes "age" and if "AgeComp" exists in
# the data distribution specification. Both tests are length-one, so the
# short-circuiting `&&` is used rather than the elementwise `&`.
if (
  "age" %in% fleet_types &&
    "AgeComp" %in% names(
      parameters[["modules"]][["fleets"]][[fleet_name]][["data_distribution"]]
    )
) {
  module$SetObservedAgeCompData(linked_ids["age_comp"])
}

# Link the observed length composition data to the fleet module using its associated ID
# if the data type includes "length" and if "LengthComp" exists in the data
# distribution specification
if ("length" %in% fleet_types &
"LengthComp" %in% names(parameters[["modules"]][["fleets"]][[fleet_name]][["data_distribution"]])) {
module$SetObservedLengthCompData(linked_ids["length_comp"])
Expand Down Expand Up @@ -611,14 +621,8 @@ initialize_fims <- function(parameters, data) {
fleet_name = fleet_names[i]
)

fleet_age_comp[[i]] <- initialize_age_comp(
data = data,
fleet_name = fleet_names[i]
)

fleet_module_ids <- c(
index = fleet_index[[i]]$get_id(),
age_comp = fleet_age_comp[[i]]$get_id(),
selectivity = fleet_selectivity[[i]]$get_id()
)

Expand All @@ -627,12 +631,36 @@ initialize_fims <- function(parameters, data) {
dplyr::pull(type) |>
unique()

# Initialize age composition module if the data type includes "age" and
# if "AgeComp" exists in the data distribution specification. Both tests are
# length-one, so the short-circuiting `&&` is used rather than the
# elementwise `&`.
if (
  "age" %in% fleet_types &&
    "AgeComp" %in% names(
      parameters[["modules"]][["fleets"]][[fleet_names[i]]][["data_distribution"]]
    )
) {
  # Initialize age composition module for the current fleet
  fleet_age_comp[[i]] <- initialize_age_comp(
    data = data,
    fleet_name = fleet_names[i]
  )

  # Add the module ID for the initialized age composition to the list of
  # fleet module IDs
  fleet_module_ids <- c(
    fleet_module_ids,
    c(age_comp = fleet_age_comp[[i]]$get_id())
  )
}

# Initialize length composition module if the data type includes "length" and
# if "LengthComp" exists in the data distribution specification
if ("length" %in% fleet_types &
"LengthComp" %in% names(parameters[["modules"]][["fleets"]][[fleet_names[i]]][["data_distribution"]])) {

# Initialize length composition module for the current fleet
fleet_length_comp[[i]] <- initialize_length_comp(
data = data,
fleet_name = fleet_names[i]
)

# Add the module ID for the initialized length composition to the list of fleet module IDs
fleet_module_ids <- c(
fleet_module_ids,
c(length_comp = fleet_length_comp[[i]]$get_id())
Expand Down Expand Up @@ -678,11 +706,14 @@ initialize_fims <- function(parameters, data) {
data_type = "index"
)

fleet_agecomp_distribution[[i]] <- initialize_data_distribution(
module = fleet[[i]],
family = multinomial(link = "logit"),
data_type = "agecomp"
)
# Set up the age composition data distribution only when the fleet has "age"
# data and "AgeComp" is listed in its data distribution specification. Both
# tests are length-one, so the short-circuiting `&&` is used rather than the
# elementwise `&`.
if (
  "age" %in% fleet_types &&
    "AgeComp" %in% names(
      parameters[["modules"]][["fleets"]][[fleet_names[i]]][["data_distribution"]]
    )
) {
  fleet_agecomp_distribution[[i]] <- initialize_data_distribution(
    module = fleet[[i]],
    family = multinomial(link = "logit"),
    data_type = "agecomp"
  )
}

if ("length" %in% fleet_types &
"LengthComp" %in% names(parameters[["modules"]][["fleets"]][[fleet_names[i]]][["data_distribution"]])) {
Expand Down
Loading

0 comments on commit c5f03ee

Please sign in to comment.