Skip to content

Commit

Permalink
Optimizations in new granular WRE
Browse files Browse the repository at this point in the history
  • Loading branch information
droodman committed Mar 7, 2023
1 parent fd6e9d8 commit 16b699d
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 77 deletions.
6 changes: 3 additions & 3 deletions src/WRE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,7 @@ function HessianFixedkappa!(o::StrBootTest{T}, dest::AbstractMatrix{T}, is::Vect
coldotminus!(_dest, 1, o.CT✻FEUv, o.invFEwtCT✻FEUv)
end

dest[row:row,:] .*= κ; dest[row:row,:] .+= (1-κ) .* _dest
view(dest,row:row,:) .*= κ; view(dest,row:row,:) .+= (1-κ) .* _dest # "view(dest,row:row,:)" avoids allocations where dest[row:row,:] doesn't
end
end

Expand Down Expand Up @@ -516,7 +516,7 @@ function Filling!(o::StrBootTest{T}, dest::AbstractMatrix{T}, i::Int64, β̈s::A
o.NFE>0 && !o.FEboot &&
(o.S✻UMZperp .+= view(o.invFEwtCT✻FEUv, o.FEID[g]:o.FEID[g], :)) # CT_(*,FE) (U ̈_(∥j) ) (S_FE S_FE^' )^(-1) S_FE

dest[g:g,:] .= o.PXY✻ .* o.DGP.ȳ₁[g]
t✻!(view(dest,g:g,:), o.DGP.ȳ₁[g], o.PXY✻)
coldotminus!(dest, g, o.PXY✻, o.S✻UMZperp)
end
else
Expand Down Expand Up @@ -554,7 +554,7 @@ function Filling!(o::StrBootTest{T}, dest::AbstractMatrix{T}, i::Int64, β̈s::A
o.NFE>0 && !o.FEboot &&
(o.S✻UMZperp .+= view(o.invFEwtCT✻FEUv, o.FEID[g]:o.FEID[g], :)) # CT_(*,FE) (U ̈_(∥j) ) (S_FE S_FE^' )^(-1) S_FE

dest[g:g,:] .-= o.PXY✻ .* (o.Z̄[g,j] * _β̈ )
t✻minus!(view(dest,g:g,:), o.Z̄[g,j], o.PXY✻, _β̈ )
coldotplus!(dest, g, o.PXY✻, o.S✻UMZperp)
else
coldotminus!(dest, g, o.PXY✻, o.Z̄[g,j] * _β̈)
Expand Down
2 changes: 1 addition & 1 deletion src/estimators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -722,7 +722,7 @@ function InitTestDenoms!(o::StrEstimator{T}, parent::StrBootTest{T}) where T

if parent.robust && parent.NFE>0 && !(parent.FEboot || parent.scorebs) && parent.granular < parent.NErrClustCombs # make first factor of second term of (64) for c=⋂ (c=1)
!isdefined(o, :WXAR) && (o.WXAR = o.XAR) # XXX simplify
o.CT_XAR = [crosstabFEt(parent, view(o.WXAR,:,d), parent.info⋂) for d 1:parent.dof]
o.CT_XAR = crosstabFE(parent, o.WXAR, parent.info⋂)
end
end
nothing
Expand Down
4 changes: 2 additions & 2 deletions src/nonWRE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ function _MakeInterpolables!(o::StrBootTest{T}, thisr::AbstractVector) where T
if o.NFE>0 && !o.FEboot
tmp = o.invFEwt .* dropdims(crosstabFE(o, o.DGP.ü₁[1+_jk], o.info✻); dims=3)
@inbounds for d 1:o.dof
K[d] .+= o.M.CT_XAR[d] * tmp
t✻plus!(K[d], view(o.M.CT_XAR,:,:,d)', tmp)
end
end
@inbounds for d 1:o.dof
Expand Down Expand Up @@ -254,7 +254,7 @@ function MakeNumerAndJ!(o::StrBootTest{T}, w::Integer, _jk::Bool, r::AbstractVec
end

function MakeNonWRELoop1!(o::StrBootTest, tmp::Matrix, w::Integer)
@inbounds #=Threads.@threads=# for k o.ncolsv:-1:1
@inbounds #= #=Threads.@threads=# =# for k o.ncolsv:-1:1
@inbounds for i 1:o.dof
for j 1:i
tmp[j,i] = o.denom[i,j][k] # fill upper triangle, which is all invsym() looks at
Expand Down
2 changes: 1 addition & 1 deletion src/structs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ mutable struct StrEstimator{T<:AbstractFloat}
invXX::Matrix{T}; Y₂::Matrix{T}; X₂::Matrix{T}; invH::Matrix{T}
y₁par::Vector{T}; Xy₁par::Vector{T}
A::Matrix{T}; Zpar::Matrix{T}; Zperp::Matrix{T}; X₁::Matrix{T}
WXAR::Matrix{T}; CT_XAR::Vector{Matrix{T}}
WXAR::Matrix{T}; CT_XAR::Array{T,3}

S✻XX::Array{T,3}; XinvHjk::Vector{Matrix{T}}; invMjk::Vector{Matrix{T}}; invMjkv::Vector{T}; XXt1jk::Matrix{T}; t₁::Vector{T}

Expand Down
101 changes: 36 additions & 65 deletions src/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,12 @@ end
@inline function t✻!(A::AbstractVecOrMat{T}, B::AbstractVecOrMat{T}, C::AbstractVecOrMat{T}) where T
mul!(A, B, C)
end
@inline function t✻!(A::AbstractVecOrMat{T}, b::T, C::AbstractVecOrMat{T}) where T
@tturbo for i indices((A,C),1), j indices((A,C),2)
A[i,j] = b * C[i,j]
end
nothing
end
function t✻(A::AbstractVecOrMat{T}, B::AbstractVector{T}) where T
dest = Vector{T}(undef, size(A,1))
mul!(dest, A, B)
Expand Down Expand Up @@ -240,6 +246,18 @@ function t✻minus!(A::AbstractMatrix{T}, B::AbstractVecOrMat{T}, C::AbstractMat
end
nothing
end
function t✻minus!(A::AbstractMatrix{T}, c::T, B::AbstractVecOrMat{T}, C::AbstractMatrix{T}) where T # add B*C to A in place
if length(B)>0 && length(C)>0
@tturbo warn_check_args=false for i indices((A,B),1), k indices((A,C),2)
Aᵢₖ = zero(T)
for j indices((B,C),(2,1))
Aᵢₖ += B[i,j] * C[j,k]
end
A[i,k] -= c * Aᵢₖ
end
end
nothing
end
function t✻minus!(A::AbstractVector{T}, B::AbstractVecOrMat{T}, C::AbstractVector{T}) where T # add B*C to A in place
if length(B)>0 && length(C)>0
@tturbo warn_check_args=false for i eachindex(axes(A,1),axes(B,1))
Expand Down Expand Up @@ -591,44 +609,32 @@ function panelcoldotminus!(dest::AbstractMatrix{T}, X::AbstractMatrix{T}, Y::Abs
end


# cross-tab sum of a column vector w.r.t. given panel info and fixed-effect var
# crosstab of a column vector w.r.t. given panel info and fixed-effect var
# one row per FE, one col per other grouping
# handling multiple columns in v
# dimensions: (FEs,entries of info, cols of v)
# this facilitates reshape() to 2-D array in which results for each col of v are stacked vertically
function crosstabFE!(o::StrBootTest{T}, dest::Array{T,3}, v::AbstractVecOrMat{T}, info::Vector{UnitRange{Int64}}) where T
if o.haswt
vw = v .* o.sqrtwt
if nrows(info)>0
fill!(dest, zero(T))
@inbounds Threads.@threads for i axes(info,1)
FEIDi = view(o._FEID, info[i])
vi = @view vw[info[i],:]
@inbounds for j axes(FEIDi,1)
dest[FEIDi[j],i,:] += @view vi[j,:]
vw = o.haswt ? v .* o.sqrtwt : v
if nrows(info)>0
fill!(dest, zero(T))
@inbounds Threads.@threads for i eachindex(axes(info,1))
infoi = info[i]
@inbounds @fastmath for infoij infoi
FEIDij = o._FEID[infoij]
for k eachindex(axes(vw,2))
dest[FEIDij,i,k] += vw[infoij,k]
end
end
else # "robust" case, no clustering
@inbounds Threads.@threads for i axes(o._FEID,1)
dest[o._FEID[i],i,:] .= @view vw[i,:]
end
end
else
if nrows(info)>0
fill!(dest, zero(T))
@inbounds Threads.@threads for i axes(info,1)
FEIDi = view(o._FEID, info[i])
vi = @view v[info[i],:]
@inbounds for j axes(FEIDi,1)
dest[FEIDi[j],i,:] += @view vi[j,:]
end
end
else # "robust" case, no clustering
@inbounds Threads.@threads for i axes(o._FEID,1)
dest[o._FEID[i],i,:] .= @view v[i,:]
else # "robust" case, no clustering
@inbounds Threads.@threads for i eachindex(axes(o._FEID,1))
FEIDi = o._FEID[i]
@inbounds for k eachindex(axes(vw,2))
dest[FEIDi,i,k] .= vw[i,k]
end
end
end
end
nothing
end
function crosstabFE(o::StrBootTest{T}, v::AbstractVecOrMat{T}, info::Vector{UnitRange{Int64}}) where T
Expand All @@ -637,41 +643,6 @@ function crosstabFE(o::StrBootTest{T}, v::AbstractVecOrMat{T}, info::Vector{Unit
dest
end

# same, transposed
function crosstabFEt(o::StrBootTest{T}, v::AbstractVector{T}, info::Vector{UnitRange{Int64}}) where T
dest = zeros(T, nrows(info), o.NFE)
if o.haswt
vw = v .* o.sqrtwt
if nrows(info)>0
@inbounds Threads.@threads for i axes(info,1)
FEIDi = @view o._FEID[info[i]]
vi = @view vw[info[i]]
@inbounds for j eachindex(vi, FEIDi)
dest[i,FEIDi[j]] += vi[j]
end
end
else # "robust" case, no clustering
@inbounds Threads.@threads for i eachindex(v,o._FEID)
dest[i,o._FEID[i]] = vw[i]
end
end
else
if nrows(info)>0
@inbounds Threads.@threads for i axes(info,1)
FEIDi = @view o._FEID[info[i]]
vi = @view v[info[i]]
@inbounds for j eachindex(vi, FEIDi)
dest[i,FEIDi[j]] += vi[j]
end
end
else # "robust" case, no clustering
@inbounds Threads.@threads for i eachindex(v,o._FEID)
dest[i,o._FEID[i]] = v[i]
end
end
end
dest
end

# partial any fixed effects out of a data matrix
function partialFE!(o::StrBootTest{T}, In::AbstractVector{T}) where T
Expand Down Expand Up @@ -706,7 +677,7 @@ end
function partialFE!(o::StrBootTest{T}, In::AbstractMatrix{T}) where T
if length(In)>0
if o.haswt
Threads.@threads for j eachindex(axes(In,2))
@inbounds #=Threads.@threads=# for j eachindex(axes(In,2))
@inbounds @fastmath for f o.FEs
fis = f.is; wt = f.wtvec; sqrtwt = f.sqrtwt
s = zero(T)
Expand All @@ -719,7 +690,7 @@ function partialFE!(o::StrBootTest{T}, In::AbstractMatrix{T}) where T
end
end
else
Threads.@threads for j eachindex(axes(In,2))
@inbounds #=Threads.@threads=# for j eachindex(axes(In,2))
@inbounds @fastmath for f o.FEs
fis = f.is
s = zero(T)
Expand Down
6 changes: 1 addition & 5 deletions test/unittests.log
Original file line number Diff line number Diff line change
Expand Up @@ -284,10 +284,6 @@ CI = [-0.4649 2.001]

ivreghdfe wage ttl_exp collgrad tenure (occupation = union married) [aw=hours] if grade<., liml cluster(industry) absorb(age)
boottest tenure
t(11) = 0.6068
p = 0.4935
CI = [-0.03753 0.05658]

boottest tenure, jk
t(11) = 0.6068
p = 0.5145
Expand Down Expand Up @@ -361,7 +357,7 @@ boottest merit, nogr reps(9999) nonull bootcluster(individual)
boottest merit, nogr reps(9999) nonull bootcluster(individual) matsize(.1)
t(41) = 6.7127
p = 0.4918
CI = [-2.814 1.285]
CI = [-2.816 1.285]

t(41) = 6.7127
p = 0.0000
Expand Down

0 comments on commit 16b699d

Please sign in to comment.