Skip to content

Commit

Permalink
Minimize matrix inversion and use Cholesky exclusively for inversion …
Browse files Browse the repository at this point in the history
…for ldiv-ing
  • Loading branch information
droodman committed Dec 23, 2023
1 parent 20bfed1 commit 4ba8316
Show file tree
Hide file tree
Showing 14 changed files with 428 additions and 341 deletions.
257 changes: 162 additions & 95 deletions Manifest.toml

Large diffs are not rendered by default.

8 changes: 1 addition & 7 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "WildBootTests"
uuid = "65c2e505-86ba-4c19-93f1-95506c1443d5"
authors = ["droodman <d.roodman@outlook.com>"]
authors = ["droodman <david@davidroodman.com>"]
version = "0.9.11"

[deps]
Expand All @@ -17,10 +17,4 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
ThreadsX = "ac1d9e8a-700a-412c-b207-f0111f4b6c0d"

[compat]
Distributions = "0.25.96"
StableRNGs = "1.0.0"
LoopVectorization = "0.12.159"
PrecompileTools = "1.1.2"
SortingAlgorithms = "1.1.1"
ThreadsX = "0.1.11"
julia = "1.8"
52 changes: 26 additions & 26 deletions src/WRE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ function InitWRE!(o::StrBootTest{T}) where T
o.S⋂ȳ₁X = Array{T,3}(undef, 1, o.N⋂, o.DGP.kX)
o.S⋂ReplZ̄X = Array{T,3}(undef, o.Repl.kZ, o.N⋂, o.DGP.kX)

o.S⋂XZperpinvZperpZperp = S⋂ZperpX' * o.DGP.invZperpZperp
o.S⋂XZperpinvZperpZperp = cholldiv(o.DGP.cholZperpZperp, S⋂ZperpX)
o.negS✻UMZperpX = [Array{T,3}(undef, o.DGP.kX, o.N⋂, o.N✻) for _ in 0:o.Repl.kZ]

_inds = o.subcluster>0 ?
Expand Down Expand Up @@ -238,16 +238,16 @@ function InitWRE!(o::StrBootTest{T}) where T
end
end

o.invXXS✻XDGPZ = @panelsum(o.DGP.invXX * o.S✻⋂XDGPZ, o.info✻_✻⋂)
o.invXXS✻Xy₁ = o.DGP.invXX * o.S✻Xy₁
o.invZperpZperpS✻ZperpY₂ = o.DGP.invZperpZperp * o.S✻ZperpY₂
o.invZperpZperpS✻ZperpX = o.DGP.invZperpZperp * o.S✻ZperpX
o.invZperpZperpS✻Zperpy₁ = o.DGP.invZperpZperp * o.S✻Zperpy₁
o.invZperpZperpS✻ZperpDGPZ = o.DGP.invZperpZperp * o.S✻ZperpDGPZ
o.invXXS✻XDGPZ = @panelsum(cholldiv(o.DGP.cholXX, o.S✻⋂XDGPZ), o.info✻_✻⋂)
o.invXXS✻Xy₁ = cholldiv(o.DGP.cholXX, o.S✻Xy₁)
o.invZperpZperpS✻ZperpY₂ = cholldiv(o.DGP.cholZperpZperp, o.S✻ZperpY₂ )
o.invZperpZperpS✻ZperpX = cholldiv(o.DGP.cholZperpZperp, o.S✻ZperpX )
o.invZperpZperpS✻Zperpy₁ = cholldiv(o.DGP.cholZperpZperp, o.S✻Zperpy₁ )
o.invZperpZperpS✻ZperpDGPZ = cholldiv(o.DGP.cholZperpZperp, o.S✻ZperpDGPZ)

if o.DGP.restricted
o.invXXS✻XDGPZR₁ = o.DGP.invXX * o.S✻XZR₁
o.invZperpZperpS✻ZperpDGPZR₁ = o.DGP.invZperpZperp * o.S✻ZperpDGPZR₁
o.invXXS✻XDGPZR₁ = cholldiv(o.DGP.cholXX, o.S✻XZR₁)
o.invZperpZperpS✻ZperpDGPZR₁ = cholldiv(o.DGP.cholZperpZperp, o.S✻ZperpDGPZR₁)
end
end
end
Expand All @@ -266,7 +266,7 @@ function PrepWRE!(o::StrBootTest{T}) where T

o.invXXXZ̄ .= o.Repl.XZ - o.DGP.XÜ₂ * o.Repl.RparY
o.XȲ .= [o.DGP.Xȳ₁ o.invXXXZ̄]
o.invXXXZ̄ .= o.Repl.invXX * o.invXXXZ̄
o.invXXXZ̄ .= cholldiv!(o.Repl.cholXX, o.invXXXZ̄)
o.ZÜ₂par .= (o.Repl.ZY₂ - o.Repl.XZ'o.DGP.Π̈ ) * o.Repl.RparY
_ȲȲ = o.DGP.γ⃛'o.Repl.XZ - o.DGP.ȳ₁Ü₂ * o.Repl.RparY
o.ȲȲ .= [o.DGP.ȳ₁ȳ₁ _ȲȲ
Expand All @@ -280,14 +280,14 @@ function PrepWRE!(o::StrBootTest{T}) where T

panelcross21!(o.S✻Xu₁, o.DGP.X₁, o.DGP.X₂, o.DGP.u⃛₁, o.info✻)
panelcross21!(o.S✻XU₂par, o.DGP.X₁, o.DGP.X₂, o.Ü₂par, o.info✻)
t✻!(o.invXXS✻Xu₁ , o.DGP.invXX, o.S✻Xu₁ )
t✻!(o.invXXS✻XU₂par, o.DGP.invXX, o.S✻XU₂par)
cholldiv!(o.invXXS✻Xu₁ , o.DGP.cholXX, o.S✻Xu₁ )
cholldiv!(o.invXXS✻XU₂par, o.DGP.cholXX, o.S✻XU₂par)

if o.willfill || o.not2SLS
panelcross!(o.S✻Zperpu₁, o.DGP.Zperp, o.DGP.u⃛₁, o.info✻)
panelcross!(o.S✻ZperpU₂par, o.DGP.Zperp, o.Ü₂par, o.info✻)
t✻!(o.invZperpZperpS✻Zperpu₁, o.DGP.invZperpZperp, o.S✻Zperpu₁)
t✻!(o.invZperpZperpS✻ZperpU₂par, o.DGP.invZperpZperp, o.S✻ZperpU₂par)
cholldiv!(o.invZperpZperpS✻Zperpu₁, o.DGP.cholZperpZperp, o.S✻Zperpu₁)
cholldiv!(o.invZperpZperpS✻ZperpU₂par, o.DGP.cholZperpZperp, o.S✻ZperpU₂par)
if o.NFE>0 && !o.FEboot
crosstabFE!(o, (@view o.CT✻FEU[1:1 ]), [o.DGP.u⃛₁], o.ID✻, o.N✻)
crosstabFE!(o, (@view o.CT✻FEU[2:end]), [o.Ü₂par ], o.ID✻, o.N✻)
Expand Down Expand Up @@ -321,7 +321,7 @@ function PrepWRE!(o::StrBootTest{T}) where T

t✻!(o.S✻XU₂, o.S✻XX, o.DGP.Π̈); o.S✻XU₂ .= o.S✻XY₂ .- o.S✻XU₂; false && o.small && (o.S✻XU₂ .*= o.DGP.m₂)
o.S✻XU₂par .= o.S✻XU₂ * o.Repl.RparY # use this syntax for 3-array x DesignerMatrix
t✻!(o.invXXS✻XU₂, o.DGP.invXX, o.S✻XU₂)
cholldiv!(o.invXXS✻XU₂, o.DGP.cholXX, o.S✻XU₂)
o.invXXS✻XU₂par .= o.invXXS✻XU₂ * o.Repl.RparY # use this syntax for 3-array x DesignerMatrix
if o.willfill || o.not2SLS
t✻!(o.S✻ZperpU₂, o.S✻ZperpX, o.DGP.Π̈); o.S✻ZperpU₂ .= o.S✻ZperpY₂ .- o.S✻ZperpU₂
Expand Down Expand Up @@ -424,7 +424,7 @@ function PrepWRE!(o::StrBootTest{T}) where T

@inbounds for j 0:o.Repl.kZ
if o.Repl.Yendog[j+1]
t✻!(o.negS✻UMZperpX[j+1], o.S⋂XZperpinvZperpZperp, o.S✻ZperpU[j+1]) # S_* diag⁡(U ̈_(∥j) ) Z_⊥ (Z_⊥^' Z_⊥ )^(-1) Z_(⊥g)^' X_(∥g)
t✻!(o.negS✻UMZperpX[j+1], o.S⋂XZperpinvZperpZperp', o.S✻ZperpU[j+1]) # S_* diag⁡(U ̈_(∥j) ) Z_⊥ (Z_⊥^' Z_⊥ )^(-1) Z_(⊥g)^' X_(∥g)
o.negS✻UMZperpX[j+1][o.crosstab⋂✻ind] .-= vec(j>0 ? view(o.S✻⋂XÜ₂par,:,:,j) : view(o.S✻⋂Xu₁,:,:,1))
if o.NFE>0 && !o.FEboot
for i 1:o.DGP.kX
Expand Down Expand Up @@ -735,24 +735,24 @@ function MakeWREStats!(o::StrBootTest{T}, w::Integer) where T

M = Matrix{T}(undef, o.Repl.kZ+1, o.Repl.kZ+1)
@inbounds for b eachindex(axes(o.κWRE,2))
ldiv!(M, bunchkaufman(view(o.YY✻,:,b,:)), view(o.YPXY✻,:,b,:))
ldiv!(M, _cholesky(view(o.YY✻,:,b,:)), view(o.YPXY✻,:,b,:))
o.κWRE[b] = one(T)/(one(T) - real(eigvalsNaN(M)[1]))
end
# view(o.κWRE,1,:,1) .= one(T) ./ (one(T) .- getindex.(real.(eigvalsNaN.(each(invsym(o.YY✻) * o.YPXY✻))), 1))
!iszero(o.fuller) && (o.κWRE .-= o.fuller / (o._Nobs - o.kX))

o.As .= o.κWRE .* view(o.YPXY✻, 2:o.Repl.kZ+1, :, 2:o.Repl.kZ+1) .+ (1 .- o.κWRE) .* view(o.YY✻, 2:o.Repl.kZ+1, :, 2:o.Repl.kZ+1)
invsym!(o.As)
t✻!(view(o.β̈s,:,:,1:1), o.As, o.κWRE .* view(o.YPXY✻, 2:o.Repl.kZ+1, :, 1) .+ (1 .- o.κWRE) .* view(o.YY✻, 2:o.Repl.kZ+1, :, 1))
o.As .= o.κWRE .* view(o.YPXY✻, 2:o.Repl.kZ+1, :, 2:o.Repl.kZ+1) .+ (1 .- o.κWRE) .* view(o.YY✻, 2:o.Repl.kZ+1, :, 2:o.Repl.kZ+1)
o.β̈s[:,:,1:1] .= o.κWRE .* view(o.YPXY✻, 2:o.Repl.kZ+1, :, 1 ) .+ (1 .- o.κWRE) .* view(o.YY✻, 2:o.Repl.kZ+1, :, 1 )
else
HessianFixedkappa!(o, o.δnumer, collect(1:o.Repl.kZ), 0, o.κ, _jk)
@inbounds for i 1:o.Repl.kZ
HessianFixedkappa!(o, view(o.As, 1:i, :, i), collect(1:i), i, o.κ, _jk)
end
symmetrize!(o.As)
invsym!(o.As)
t✻!(view(o.β̈s,:,:,1:1), o.As, view(o.δnumer,:,:,1:1))
o.β̈s[:,:,1:1] = view(o.δnumer,:,:,1:1)
end
cholAs = _cholesky!(o.As)
cholldiv!(cholAs, view(o.β̈s,:,:,1:1))

if o.bootstrapt
if o.robust
Expand All @@ -776,7 +776,7 @@ function MakeWREStats!(o::StrBootTest{T}, w::Integer) where T

if o.bootstrapt
if o.robust # Compute denominator for this WRE test stat
t✻!(o.ARpars, o.As, o.Repl.RRpar')
cholldiv!(o.ARpars, cholAs, Matrix(o.Repl.RRpar'))
t✻!(o.J⋂ARpars[1], o.J⋂s, o.ARpars)
t✻!(o.denomWRE, o.clust[1].multiplier, o.J⋂ARpars[1]', o.J⋂ARpars[1])
for c 2:o.NErrClustCombs
Expand All @@ -787,7 +787,7 @@ function MakeWREStats!(o::StrBootTest{T}, w::Integer) where T
end
else # non-robust
tmp = view([fill(T(-1), 1, o.ncolsv) ; o.β̈s], :, :, 1:1)
o.denomWRE .= (o.Repl.RRpar * o.As * o.Repl.RRpar') .* (tmp'o.YY✻ * tmp) # 2nd half is sig2 of errors
o.denomWRE .= (o.Repl.RRpar * cholldiv(cholAs, Matrix(o.Repl.RRpar'))) .* (tmp'o.YY✻ * tmp) # 2nd half is sig2 of errors
end
if w==1
o.statDenom = o.denomWRE[:,1,:]
Expand All @@ -796,9 +796,9 @@ function MakeWREStats!(o::StrBootTest{T}, w::Integer) where T
if o.sqrt
@storeWtGrpResults!(o.dist, o.numerWRE ./ sqrtNaN.(dropdims(o.denomWRE; dims=3)))
else
invsym!(o.denomWRE)
choldenomWRE = _cholesky!(o.denomWRE)
_numer = view(o.numerWRE,:,:,1:1)
@storeWtGrpResults!(o.dist, dropdims(_numer'o.denomWRE*_numer; dims=3)) # hand-code for 2-dimensional? XXX allocations
@storeWtGrpResults!(o.dist, dropdims(_numer'cholldiv(choldenomWRE,_numer); dims=3)) # hand-code for 2-dimensional? XXX allocations
end
else
@storeWtGrpResults!(o.numer, o.numerWRE)
Expand Down
3 changes: 2 additions & 1 deletion src/WildBootTests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ function NoNullUpdate!(o::StrBootTest{T} where T)
else
o.numer[:,1] = o.R * (o.ml ? o.β̈ : iszero(o.κ) ? view(o.M.β̈ ,:,1) : o.M.Rpar * view(o.M.β̈ ,:,1)) - o.r # Analytical Wald numerator; if imposing null then numer[:,1] already equals this. If not, then it's 0 before this
end
o.dist[1] = isone(o.dof) ? o.numer[1] / sqrtNaN(o.statDenom[1]) : o.numer[:,1]'invsym(o.statDenom)*o.numer[:,1]
numer₁ = o.numer[:,1]
o.dist[1] = isone(o.dof) ? o.numer[1] / sqrtNaN(o.statDenom[1]) : numer₁'cholldiv!(_cholesky(o.statDenom), numer₁)
nothing
end

Expand Down
Loading

0 comments on commit 4ba8316

Please sign in to comment.