Skip to content

Commit

Permalink
Merge pull request #276 from rbchan/gdistsamp_refactor2
Browse files Browse the repository at this point in the history
Refactor gdistsamp likelihood to avoid occasional crashes
  • Loading branch information
kenkellner authored Mar 3, 2024
2 parents 85009ab + 709d3d0 commit 8c06592
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 16 deletions.
7 changes: 4 additions & 3 deletions R/gdistsamp.R
Original file line number Diff line number Diff line change
Expand Up @@ -399,8 +399,9 @@ if(engine =="C"){
as.vector(t(out))
}
y_long <- long_format(y)
kmytC <- kmyt
kmytC[which(is.na(kmyt))] <- 0
# Vectorize these arrays as using arma::subcube sometimes crashes
kmytC <- as.vector(aperm(kmyt, c(3,2,1)))
lfac.kmytC <- as.vector(aperm(lfac.kmyt, c(3,2,1)))
if(output!='density'){
A <- rep(1, M)
}
Expand All @@ -411,7 +412,7 @@ if(engine =="C"){
nll <- function(params){
nll_gdistsamp(params, n_param, y_long, mixture_code, keyfun, survey,
Xlam, Xlam.offset, A, Xphi, Xphi.offset, Xdet, Xdet.offset,
db, a, t(u), w, k, lfac.k, lfac.kmyt, kmyt, Kmin, threads)
db, a, t(u), w, k, lfac.k, lfac.kmytC, kmytC, Kmin, threads)
}

} else {
Expand Down
6 changes: 3 additions & 3 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ BEGIN_RCPP
END_RCPP
}
// nll_gdistsamp
double nll_gdistsamp(arma::vec beta, arma::uvec n_param, arma::vec y, int mixture, std::string keyfun, std::string survey, arma::mat Xlam, arma::vec Xlam_offset, arma::vec A, arma::mat Xphi, arma::vec Xphi_offset, arma::mat Xdet, arma::vec Xdet_offset, arma::vec db, arma::mat a, arma::mat u, arma::vec w, arma::vec k, arma::vec lfac_k, arma::cube lfac_kmyt, arma::cube kmyt, arma::uvec Kmin, int threads);
double nll_gdistsamp(arma::vec beta, arma::uvec n_param, arma::vec y, int mixture, std::string keyfun, std::string survey, arma::mat Xlam, arma::vec Xlam_offset, arma::vec A, arma::mat Xphi, arma::vec Xphi_offset, arma::mat Xdet, arma::vec Xdet_offset, arma::vec db, arma::mat a, arma::mat u, arma::vec w, arma::vec k, arma::vec lfac_k, arma::vec lfac_kmyt, arma::vec kmyt, arma::uvec Kmin, int threads);
RcppExport SEXP _unmarked_nll_gdistsamp(SEXP betaSEXP, SEXP n_paramSEXP, SEXP ySEXP, SEXP mixtureSEXP, SEXP keyfunSEXP, SEXP surveySEXP, SEXP XlamSEXP, SEXP Xlam_offsetSEXP, SEXP ASEXP, SEXP XphiSEXP, SEXP Xphi_offsetSEXP, SEXP XdetSEXP, SEXP Xdet_offsetSEXP, SEXP dbSEXP, SEXP aSEXP, SEXP uSEXP, SEXP wSEXP, SEXP kSEXP, SEXP lfac_kSEXP, SEXP lfac_kmytSEXP, SEXP kmytSEXP, SEXP KminSEXP, SEXP threadsSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Expand All @@ -165,8 +165,8 @@ BEGIN_RCPP
Rcpp::traits::input_parameter< arma::vec >::type w(wSEXP);
Rcpp::traits::input_parameter< arma::vec >::type k(kSEXP);
Rcpp::traits::input_parameter< arma::vec >::type lfac_k(lfac_kSEXP);
Rcpp::traits::input_parameter< arma::cube >::type lfac_kmyt(lfac_kmytSEXP);
Rcpp::traits::input_parameter< arma::cube >::type kmyt(kmytSEXP);
Rcpp::traits::input_parameter< arma::vec >::type lfac_kmyt(lfac_kmytSEXP);
Rcpp::traits::input_parameter< arma::vec >::type kmyt(kmytSEXP);
Rcpp::traits::input_parameter< arma::uvec >::type Kmin(KminSEXP);
Rcpp::traits::input_parameter< int >::type threads(threadsSEXP);
rcpp_result_gen = Rcpp::wrap(nll_gdistsamp(beta, n_param, y, mixture, keyfun, survey, Xlam, Xlam_offset, A, Xphi, Xphi_offset, Xdet, Xdet_offset, db, a, u, w, k, lfac_k, lfac_kmyt, kmyt, Kmin, threads));
Expand Down
34 changes: 24 additions & 10 deletions src/nll_gdistsamp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ double nll_gdistsamp(arma::vec beta, arma::uvec n_param, arma::vec y,
arma::mat Xlam, arma::vec Xlam_offset, arma::vec A, arma::mat Xphi,
arma::vec Xphi_offset, arma::mat Xdet, arma::vec Xdet_offset, arma::vec db,
arma::mat a, arma::mat u, arma::vec w, arma::vec k, arma::vec lfac_k,
arma::cube lfac_kmyt, arma::cube kmyt, arma::uvec Kmin, int threads){
arma::vec lfac_kmyt, arma::vec kmyt, arma::uvec Kmin, int threads){

#ifdef _OPENMP
omp_set_num_threads(threads);
Expand All @@ -27,7 +27,8 @@ double nll_gdistsamp(arma::vec beta, arma::uvec n_param, arma::vec y,
int T = Xphi.n_rows / M;
int R = y.size() / M;
unsigned J = R / T;
int K = k.size() - 1;
int lk = k.size();
int K = lk - 1;

//Abundance
const vec lambda = exp(Xlam * beta_sub(beta, n_param, 0) + Xlam_offset) % A;
Expand All @@ -53,9 +54,18 @@ double nll_gdistsamp(arma::vec beta, arma::uvec n_param, arma::vec y,

int t_ind = i * T;
int y_ind = i * T * J;
int k_start = i * T * lk;

vec y_sub(J);

vec p(J);
vec p1(lk);
vec p3(J);
vec p4(lk);
double p5;

//Some unnecessary calculations here when k < Kmin
//These values are ignored later in calculation of site_lp
//However hard to avoid without refactoring entirely I think
mat mn = zeros(K+1, T);
for(int t=0; t<T; t++){
int y_stop = y_ind + J - 1;
Expand All @@ -64,23 +74,27 @@ double nll_gdistsamp(arma::vec beta, arma::uvec n_param, arma::vec y,

if(not_missing.size() == J){

vec p1 = lfac_kmyt.subcube(span(i),span(t),span());
vec p = distprob(keyfun, det_param(t_ind), scale, survey, db,
int k_stop = k_start + lk - 1;

p1 = lfac_kmyt.subvec(k_start, k_stop);

p = distprob(keyfun, det_param(t_ind), scale, survey, db,
w, a.row(i));
vec p3 = p % u.col(i) * phi(t_ind);
//the following line causes a segfault only in R CMD check,
//when kmyt contains NA values
vec p4 = kmyt.subcube(span(i),span(t),span());
p3 = p % u.col(i) * phi(t_ind);

p4 = kmyt.subvec(k_start, k_stop);

double p5 = 1 - sum(p3);
p5 = 1 - sum(p3);

mn.col(t) = lfac_k - p1 + sum(y_sub % log(p3)) + p4 * log(p5);
}

t_ind += 1;
y_ind += J;
k_start += lk;
}

//Note that rows of mn for k < Kmin are skipped here
double site_lp = 0.0;
for (int j=Kmin(i); j<(K+1); j++){
site_lp += N_density(mixture, j, lambda(i), log_alpha) *
Expand Down

0 comments on commit 8c06592

Please sign in to comment.