Skip to content

Commit

Permalink
adding log_prior argument to dsemRTMB
Browse files Browse the repository at this point in the history
  • Loading branch information
James-Thorson committed Dec 31, 2024
1 parent 380ad44 commit 829b0c8
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 12 deletions.
18 changes: 14 additions & 4 deletions R/get_jnll.R → R/compute_nll.R
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
# model function
get_jnll <-
compute_nll <-
function( parlist,
model,
y_tj,
family,
options ) {
options,
log_prior ) {
# options[1] -> 0: full rank; 1: rank-reduced GMRF
# options[2] -> 0: constant conditional variance; 1: constant marginal variance

Expand Down Expand Up @@ -61,7 +62,9 @@ function( parlist,
# Full rank GMRF
z_tj = x_tj

# Doesn't work
# Doesn't work ... mat2triplet not implemented
#V_kk = matrix(0, nrow=n_k, ncol=n_k)
#REPORT( V_kk )
#V_kk = as.matrix(V_kk)
#invV_kk = solve(V_kk)
#Qtmp_kk = invV_kk
Expand All @@ -75,6 +78,8 @@ function( parlist,
Q_kk = t(IminusRho_kk) %*% invV_kk %*% IminusRho_kk

# Experiment
#IminusRho_dense = as.matrix( IminusRho_kk )
#V_dense = as.matrix( V_kk )
#Q_RHS = solve(V_kk, IminusRho_kk)
#Q_kk = t(IminusRho_kk) %*% Q_RHS

Expand Down Expand Up @@ -157,8 +162,12 @@ function( parlist,
devresid_tj[t,j] = sign(y_tj[t,j] - mu_tj[t,j]) * pow(2 * ( (y_tj[t,j]-mu_tj[t,j])/mu_tj[t,j] - log(y_tj[t,j]/mu_tj[t,j]) ), 0.5);
}
}}

# Calculate priors
log_prior_value = log_prior( parlist )

jnll = -1 * sum(loglik_tj);
jnll = jnll + jnll_gmrf;
jnll = jnll + jnll_gmrf - log_prior_value;

#
REPORT( loglik_tj )
Expand All @@ -175,6 +184,7 @@ function( parlist,
REPORT( jnll )
REPORT( loglik_tj )
REPORT( jnll_gmrf )
REPORT( log_prior_value )
#SIMULATE{
# REPORT( y_tj )
#}
Expand Down
2 changes: 1 addition & 1 deletion R/dsem.R
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ function( sem,
obj$gr_orig = obj$gr

# BUild prior evaluator
requireNamespace(RTMB)
requireNamespace("RTMB")
priors_obj = RTMB::MakeADFun( func = prior_negloglike,
parameters = list(par=obj$par),
silent = TRUE )
Expand Down
15 changes: 12 additions & 3 deletions R/dsemRTMB.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ function( sem,
tsdata,
family = rep("fixed",ncol(tsdata)),
estimate_delta0 = FALSE,
prior_negloglike = NULL,
log_prior = function(p) 0,
control = dsem_control(),
covs = colnames(tsdata) ){

Expand Down Expand Up @@ -50,6 +50,14 @@ function( sem,
n_j = ncol(y_tj)
n_k = prod(dim(y_tj))

# Load data in environment for function "dBdt"
data4 = local({
"c" <- ADoverload("c")
"[<-" <- ADoverload("[<-")
environment()
})
environment(log_prior) <- data4

# Construct parameters
if( is.null(control$parameters) ){
Params = list(
Expand Down Expand Up @@ -110,11 +118,12 @@ function( sem,
cmb <- function(f, ...) function(p) f(p, ...) ## Helper to make closure
#f(parlist, model, tsdata, family)
obj = RTMB::MakeADFun(
func = cmb( get_jnll,
func = cmb( compute_nll,
model = model,
y_tj = y_tj,
family = family,
options = options ),
options = options,
log_prior = log_prior ),
parameters = Params,
random = Random,
map = Map,
Expand Down
21 changes: 17 additions & 4 deletions scratch/test_dsemRTMB.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@

devtools::install_local( R'(C:\Users\James.Thorson\Desktop\Git\dsem)', force=TRUE, dep=FALSE )
#devtools::install_local( R'(C:\Users\James.Thorson\Desktop\Git\dsem)', force=TRUE, dep=FALSE )

library(dsem)
library(RTMB)
Expand Down Expand Up @@ -44,11 +44,15 @@ Map$lnsigma_j = factor( rep(NA,ncol(tsdata)) )
Params = fit0$tmb_inputs$parameters
Params$lnsigma_j[] = log(0.1)

#
prior_negloglike = \(obj) -dnorm(obj$par[1],0,0.1,log=TRUE)

# Fit model
fit = dsem( sem=sem,
tsdata = tsdata,
estimate_delta0 = TRUE,
family = rep("normal",ncol(tsdata)),
prior_negloglike = prior_negloglike,
control = dsem_control( quiet=TRUE,
run_model = TRUE,
use_REML = TRUE,
Expand All @@ -62,20 +66,29 @@ if( FALSE ){
covs = colnames(tsdata)
}

#

###################
# dsemRTMB
###################

# Files
source( file.path(R'(C:\Users\James.Thorson\Desktop\Git\dsem\R)', "make_matrices.R") )
source( file.path(R'(C:\Users\James.Thorson\Desktop\Git\dsem\R)', "get_jnll.R") )
source( file.path(R'(C:\Users\James.Thorson\Desktop\Git\dsem\R)', "compute_nll.R") )
source( file.path(R'(C:\Users\James.Thorson\Desktop\Git\dsem\R)', "read_model.R") )
source( file.path(R'(C:\Users\James.Thorson\Desktop\Git\dsem\R)', "dsemRTMB.R") )

# Define prior
log_prior = function(p) dnorm( p$beta_z[1], mean=0, sd=0.1, log=TRUE)

fitRTMB = dsemRTMB( sem = sem,
tsdata = tsdata,
estimate_delta0 = TRUE,
family = rep("normal",ncol(tsdata)),
log_prior = log_prior,
control = dsem_control( quiet = FALSE,
run_model = TRUE,
use_REML = TRUE,
gmrf_parameterization = "projection",
trace = 1,
map = Map,
parameters = Params ) )
obj = fitRTMB$obj
Expand Down

0 comments on commit 829b0c8

Please sign in to comment.