Skip to content

Commit

Permalink
Refactor and shorten the initialization to use condition_on!
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanaelbosch committed Oct 11, 2023
1 parent e6f98fa commit 51b6b02
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 68 deletions.
27 changes: 3 additions & 24 deletions src/initialization/classicsolverinit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,10 @@ function initial_update!(integ, cache, ::ClassicSolverInit)
# Initialize on u0; taking special care for DynamicalODEProblems
is_secondorder = integ.f isa DynamicalODEFunction
_u = is_secondorder ? view(u.x[2], :) : view(u, :)
# condition_on!(x, Proj(0), _u, cache)
begin
H = Proj(0)
measurement.μ .= H*x.μ - _u
fast_X_A_Xt!(measurement.Σ, x.Σ, H)
copy!(x_tmp, x)
update!(x, x_tmp, measurement, H, K1, C_Dxd, C_DxD, C_dxd)
end
condition_on!(x, Proj(0), _u, cache)
is_secondorder ? f.f1(du, u.x[1], u.x[2], p, t) : f(du, u, p, t)
integ.stats.nf += 1
# condition_on!(x, Proj(1), view(du, :), cache)
begin
H = Proj(1)
measurement.μ .= H*x.μ - view(du, :)
fast_X_A_Xt!(measurement.Σ, x.Σ, H)
copy!(x_tmp, x)
update!(x, x_tmp, measurement, H, K1, C_Dxd, C_DxD, C_dxd)
end
condition_on!(x, Proj(1), view(du, :), cache)

if q < 2
return
Expand All @@ -53,14 +39,7 @@ function initial_update!(integ, cache, ::ClassicSolverInit)
ForwardDiff.jacobian!(ddu, (du, u) -> _f(du, u, p, t), du, u)
end
ddfddu = ddu * view(du, :) + view(dfdt, :)
# condition_on!(x, Proj(2), ddfddu, cache)
begin
H = Proj(2)
measurement.μ .= H*x.μ - ddfddu
fast_X_A_Xt!(measurement.Σ, x.Σ, H)
copy!(x_tmp, x)
update!(x, x_tmp, measurement, H, K1, C_Dxd, C_DxD, C_dxd)
end
condition_on!(x, Proj(2), ddfddu, cache)
if q < 3
return
end
Expand Down
46 changes: 6 additions & 40 deletions src/initialization/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,44 +74,10 @@ function condition_on!(
data::AbstractVector,
cache,
)
@unpack m_tmp, K1, x_tmp, C_DxD = cache
S = m_tmp.Σ
covcache = x_tmp.Σ
Mcache = cache.C_DxD

fast_X_A_Xt!(S, x.Σ, H)
# @assert isdiag(Matrix(S))
S_diag = diag(S)
if any(iszero.(S_diag)) # could happen with a singular mass-matrix
S_diag .+= 1e-20
end

# _matmul!(K1, x.Σ.R', _matmul!(cache.C_Dxd, x.Σ.R, H'))
# K = K1 ./= S_diag'

_K = x.Σ.R' * x.Σ.R * H'
@assert all(S_diag .== 1)
K = _K

# x.μ .+= K*(data - z)
datadiff = _matmul!(data, H, x.μ, -1, 1)
_matmul!(x.μ, K, datadiff, 1, 1)

D = length(x.μ)
# _matmul!(Mcache, K, H, -1, 0)
# @inbounds @simd ivdep for i in 1:D
# Mcache[i, i] += 1
# end

d, q1 = size(H.A, 1), size(x.Σ.R.B, 1)
_I = kronecker(I(d), I(q1))
KH = K*H
@assert _I.A == KH.A
@. KH.B = _I.B - KH.B
M = KH

fast_X_A_Xt!(x_tmp.Σ, x.Σ, M)
copy!(x.Σ.R.A, covcache.R.A)
copy!(x.Σ.R.B, covcache.R.B)
return nothing
@unpack x_tmp, K1, C_Dxd, C_DxD, C_dxd, measurement = cache
_matmul!(measurement.μ, H, x.μ)
measurement.μ .-= data
fast_X_A_Xt!(measurement.Σ, x.Σ, H)
copy!(x_tmp, x)
update!(x, x_tmp, measurement, H, K1, C_Dxd, C_DxD, C_dxd)
end
5 changes: 1 addition & 4 deletions src/initialization/taylormode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,7 @@ function initial_update!(integ, cache, init::TaylorModeInit)
end

H = f.mass_matrix * Proj(o)
measurement.μ .= H*x.μ - df
fast_X_A_Xt!(measurement.Σ, x.Σ, H)
copy!(x_tmp, x)
update!(x, x_tmp, measurement, H, K1, C_Dxd, C_DxD, C_dxd)
condition_on!(x, H, df, cache)
end
end

Expand Down

0 comments on commit 51b6b02

Please sign in to comment.