Skip to content

Commit

Permalink
avoid matrix multiplication at every iteration for computing residuals
Browse files Browse the repository at this point in the history
  • Loading branch information
evangorstein committed Dec 5, 2024
1 parent c94eb08 commit f5581e4
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 7 deletions.
6 changes: 5 additions & 1 deletion src/HighDimMixedModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ function hdmm(

#Algorithm allocations
βiter = copy(βstart)
res_iter = y - XG * βiter
Liter = copy(Lstart)
σ²iter = σ²start
Lvec_iter = ndims(Liter) == 2 ? vec(Liter) : Liter #L as vector if ψstr == "sym"
Expand Down Expand Up @@ -255,8 +256,9 @@ function hdmm(

#Update fixed effect parameters that are in active_set
for j in active_set

# we also pass XG and y instead of XGgrp and ygrp for reasons of efficiency--see definition of special_quad
cut = special_quad(XG, y, βiter, j, invVgrp, XGgrp, grp)
cut = special_quad(res_iter, XG[:,j], βiter[j], grp, invVgrp, XGgrp)

if hess[j] == hess_untrunc[j] #Outcome of Armijo rule can be computed analytically
if j in 1:q
Expand Down Expand Up @@ -286,6 +288,8 @@ function hdmm(
control,
)
end
# update residuals
res_iter .-= (βiter[j] - βold[j])*XG[:, j]
end

#---Optimization with respect to random effect parameters ----------------------------
Expand Down
10 changes: 4 additions & 6 deletions src/helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -175,16 +175,14 @@ Calculate (y-ỹ)' \\* invV \\* X[:,j], where ỹ are the fitted values if we ig
To improve perforamce, we calculate ỹ at once with the entire dataset.
We then split into groups and calculate (y-ỹ)' \\* invV \\* X[:,j] for each group
"""
function special_quad(XG, y, β, j, invVgrp, XGgrp, grp)
function special_quad(res, XGj, βj, grp, invVgrp, XGgrp)

XGmiss = @views XG[:, [1:j-1 ; j+1:end]]
βmiss = @views β[[1:j-1; j+1:end]]
resid = y - XGmiss * βmiss
res_miss = res + XGj*βj

residgrp = [resid[grp.==group] for group in unique(grp)]
res_miss_grp = [res_miss[grp.==group] for group in unique(grp)]

quads =
[resid' * invV * XG[:, j] for (resid, invV, XG) in zip(residgrp, invVgrp, XGgrp)]
[resid' * invV * XG[:, j] for (resid, invV, XG) in zip(res_miss_grp, invVgrp, XGgrp)]

return sum(quads)

Expand Down

0 comments on commit f5581e4

Please sign in to comment.