Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue 735: Fix truncation slicing when t < truncation #736

Merged
merged 9 commits into from
Aug 9, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
- a bug was fixed that caused delay option functions to report an error if only the tolerance was specified. By @sbfnk in #716 and reviewed by @jamesmbaazam.
- a bug was fixed where `forecast_secondary()` did not work with fixed delays. By @sbfnk in #717 and reviewed by @seabbs.
- a bug was fixed that caused delay option functions to report an error if only the tolerance was specified. By @sbfnk.
- a bug was fixed that led to the truncation PMF being shorten from the wrong side when the truncation PMF was longer than the supplied data. By @seabbs in #736 and reviewed by @sbfnk.
seabbs marked this conversation as resolved.
Show resolved Hide resolved

## Documentation

Expand Down
2 changes: 1 addition & 1 deletion inst/stan/estimate_infections.stan
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ transformed parameters {
);
}
profile("truncate") {
obs_reports = truncate(reports[1:ot], trunc_rev_cmf, 0);
obs_reports = truncate_obs(reports[1:ot], trunc_rev_cmf, 0);
}
} else {
obs_reports = reports[1:ot];
Expand Down
2 changes: 1 addition & 1 deletion inst/stan/estimate_secondary.stan
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ transformed parameters {
delay_np_pmf_groups, delay_params, delay_params_groups, delay_dist,
0, 1, 1
);
secondary = truncate(secondary, trunc_rev_cmf, 0);
secondary = truncate_obs(secondary, trunc_rev_cmf, 0);
}
}

Expand Down
6 changes: 3 additions & 3 deletions inst/stan/estimate_truncation.stan
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,12 @@ transformed parameters{
vector[t] last_obs;
// reconstruct latest data without truncation

last_obs = truncate(to_vector(obs[, obs_sets]), trunc_rev_cmf, 1);
last_obs = truncate_obs(to_vector(obs[, obs_sets]), trunc_rev_cmf, 1);
// apply truncation to latest dataset to map back to previous data sets and
// add noise term
for (i in 1:(obs_sets - 1)) {
trunc_obs[1:(end_t[i] - start_t[i] + 1), i] =
truncate(last_obs[start_t[i]:end_t[i]], trunc_rev_cmf, 0) + sigma;
truncate_obs(last_obs[start_t[i]:end_t[i]], trunc_rev_cmf, 0) + sigma;
}
}
}
Expand Down Expand Up @@ -80,7 +80,7 @@ generated quantities {
matrix[delay_type_max[trunc_id] + 1, obs_sets - 1] gen_obs;
// reconstruct all truncated datasets using posterior of the truncation distribution
for (i in 1:obs_sets) {
recon_obs[1:(end_t[i] - start_t[i] + 1), i] = truncate(
recon_obs[1:(end_t[i] - start_t[i] + 1), i] = truncate_obs(
to_vector(obs[start_t[i]:end_t[i], i]), trunc_rev_cmf, 1
);
}
Expand Down
110 changes: 97 additions & 13 deletions inst/stan/functions/observation_model.stan
Original file line number Diff line number Diff line change
@@ -1,4 +1,14 @@
// apply day of the week effect
/**
* Apply day of the week effect to reports
*
* This function applies a day of the week effect to a vector of reports.
*
* @param reports Vector of reports to be adjusted.
* @param day_of_week Array of integers representing the day of the week for each report.
* @param effect Vector of day of week effects.
*
* @return A vector of reports adjusted for day of the week effects.
*/
vector day_of_week_effect(vector reports, array[] int day_of_week, vector effect) {
int t = num_elements(reports);
int wl = num_elements(effect);
Expand All @@ -11,30 +21,65 @@ vector day_of_week_effect(vector reports, array[] int day_of_week, vector effect
}
return(scaled_reports);
}
// Scale observations by fraction reported and update log density of
// fraction reported

/**
* Scale observations by fraction reported
*
* This function scales a vector of reports by a fraction observed.
*
* @param reports Vector of reports to be scaled.
* @param frac_obs Real value representing the fraction observed.
*
* @return A vector of scaled reports.
*/
vector scale_obs(vector reports, real frac_obs) {
int t = num_elements(reports);
vector[t] scaled_reports;
scaled_reports = reports * frac_obs;
return(scaled_reports);
}
// Truncate observed data by some truncation distribution
vector truncate(vector reports, vector trunc_rev_cmf, int reconstruct) {

/**
* Truncate observed data by a truncation distribution
*
* This function truncates a vector of reports based on a truncation distribution.
*
* @param reports Vector of reports to be truncated.
* @param trunc_rev_cmf Vector representing the reverse cumulative mass function of the truncation distribution.
* @param reconstruct Integer flag indicating whether to reconstruct (1) or truncate (0) the data.
*
* @return A vector of truncated reports.
*/
vector truncate_obs(vector reports, vector trunc_rev_cmf, int reconstruct) {
int t = num_elements(reports);
int trunc_max = num_elements(trunc_rev_cmf);
vector[t] trunc_reports = reports;
// Calculate cmf of truncation delay
int trunc_max = min(t, num_elements(trunc_rev_cmf));
int first_t = t - trunc_max + 1;
int joint_max = min(t, trunc_max);
int first_t = t - joint_max + 1;
int first_trunc = trunc_max - joint_max + 1;

// Apply cdf of truncation delay to truncation max last entries in reports
if (reconstruct) {
trunc_reports[first_t:t] ./= trunc_rev_cmf[1:trunc_max];
trunc_reports[first_t:t] ./= trunc_rev_cmf[first_trunc:trunc_max];
} else {
trunc_reports[first_t:t] .*= trunc_rev_cmf[1:trunc_max];
trunc_reports[first_t:t] .*= trunc_rev_cmf[first_trunc:trunc_max];
}
return(trunc_reports);
}
// Truncation distribution priors

/**
* Update log density for truncation distribution priors
*
* This function updates the log density for truncation distribution priors.
*
* @param truncation_mean Array of real values for truncation mean.
* @param truncation_sd Array of real values for truncation standard deviation.
* @param trunc_mean_mean Array of real values for mean of truncation mean prior.
* @param trunc_mean_sd Array of real values for standard deviation of truncation mean prior.
* @param trunc_sd_mean Array of real values for mean of truncation standard deviation prior.
* @param trunc_sd_sd Array of real values for standard deviation of truncation standard deviation prior.
*/
void truncation_lp(array[] real truncation_mean, array[] real truncation_sd,
array[] real trunc_mean_mean, array[] real trunc_mean_sd,
array[] real trunc_sd_mean, array[] real trunc_sd_sd) {
Expand All @@ -50,7 +95,22 @@ void truncation_lp(array[] real truncation_mean, array[] real truncation_sd,
}
}
}
// update log density for reported cases

/**
* Update log density for reported cases
*
* This function updates the log density for reported cases based on the specified model type.
*
* @param cases Array of integer observed cases.
* @param cases_time Array of integer time indices for observed cases.
* @param reports Vector of expected reports.
* @param rep_phi Array of real values for reporting overdispersion.
* @param phi_mean Real value for mean of reporting overdispersion prior.
* @param phi_sd Real value for standard deviation of reporting overdispersion prior.
* @param model_type Integer indicating the model type (0 for Poisson, >0 for Negative Binomial).
* @param weight Real value for weighting the log density contribution.
* @param accumulate Integer flag indicating whether to accumulate reports (1) or not (0).
*/
void report_lp(array[] int cases, array[] int cases_time, vector reports,
array[] real rep_phi, real phi_mean, real phi_sd,
int model_type, real weight, int accumulate) {
Expand Down Expand Up @@ -96,7 +156,20 @@ void report_lp(array[] int cases, array[] int cases_time, vector reports,
}
}
}
// update log likelihood (as above but not vectorised and returning log likelihood)

/**
* Calculate log likelihood for reported cases
*
* This function calculates the log likelihood for reported cases based on the specified model type.
*
* @param cases Array of integer observed cases.
* @param reports Vector of expected reports.
* @param rep_phi Array of real values for reporting overdispersion.
* @param model_type Integer indicating the model type (0 for Poisson, >0 for Negative Binomial).
* @param weight Real value for weighting the log likelihood contribution.
*
* @return A vector of log likelihoods for each time point.
*/
vector report_log_lik(array[] int cases, vector reports,
array[] real rep_phi, int model_type, real weight) {
int t = num_elements(reports);
Expand All @@ -115,7 +188,18 @@ vector report_log_lik(array[] int cases, vector reports,
}
return(log_lik);
}
// sample reported cases from the observation model

/**
* Generate random samples of reported cases
*
* This function generates random samples of reported cases based on the specified model type.
*
* @param reports Vector of expected reports.
* @param rep_phi Array of real values for reporting overdispersion.
* @param model_type Integer indicating the model type (0 for Poisson, >0 for Negative Binomial).
*
* @return An array of integer sampled reports.
*/
array[] int report_rng(vector reports, array[] real rep_phi, int model_type) {
int t = num_elements(reports);
array[t] int sampled_reports;
Expand Down
2 changes: 1 addition & 1 deletion inst/stan/simulate_infections.stan
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ generated quantities {
delay_np_pmf_groups, delay_params[i], delay_params_groups, delay_dist,
0, 1, 1
);
reports[i] = to_row_vector(truncate(
reports[i] = to_row_vector(truncate_obs(
to_vector(reports[i]), trunc_rev_cmf, 0)
);
}
Expand Down
2 changes: 1 addition & 1 deletion inst/stan/simulate_secondary.stan
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ generated quantities {
delay_np_pmf_groups, delay_params[i], delay_params_groups, delay_dist,
0, 1, 1
);
secondary = truncate(
secondary = truncate_obs(
secondary, trunc_rev_cmf, 0
);
}
Expand Down
30 changes: 30 additions & 0 deletions tests/testthat/test-stan-truncate.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
skip_on_cran()
skip_on_os("windows")

test_that("truncate_obs() can perform truncation as expected", {
reports <- c(10, 20, 30, 40, 50)
trunc_rev_cmf <- c(1, 0.8, 0.5, 0.2)
expected <- c(reports[1], reports[2:5] * trunc_rev_cmf)
expect_equal(truncate_obs(reports, trunc_rev_cmf, FALSE), expected)
})

test_that("truncate_obs() can perform reconstruction as expected", {
reports <- c(10, 20, 15, 8, 10)
trunc_rev_cmf <- c(1, 0.8, 0.5, 0.2)
expected <- c(reports[1], reports[2:5] / trunc_rev_cmf)
expect_equal(truncate_obs(reports, trunc_rev_cmf, TRUE), expected)
})

test_that("truncate_obs() can handle longer trunc_rev_cmf than reports", {
reports <- c(10, 20, 30)
trunc_rev_cmf <- c(1, 0.8, 0.5, 0.2, 0.1)
expected <- reports * trunc_rev_cmf[3:5]
expect_equal(truncate_obs(reports, trunc_rev_cmf, FALSE), expected)
})

test_that("truncate_obs() can handle reconstruction with longer trunc_rev_cmf than reports", {
reports <- c(10, 16, 15)
trunc_rev_cmf <- c(1, 0.8, 0.5, 0.2, 0.1)
expected <- reports / trunc_rev_cmf[3:5]
expect_equal(truncate_obs(reports, trunc_rev_cmf, TRUE), expected)
})
Loading