Skip to content

Commit

Permalink
Fix some things here and there
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanaelbosch committed Oct 20, 2023
1 parent a7845e2 commit 34b2424
Show file tree
Hide file tree
Showing 7 changed files with 37 additions and 28 deletions.
4 changes: 1 addition & 3 deletions src/filtering/smooth.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ function smooth!(
_D = size(C_DxD, 1)

# Prediction: t -> t+1
# The following things are there to handle the kronecker case; will be refactored
predict_mean!(x_pred.μ, x_curr.μ, Ah)
predict_cov!(x_pred.Σ, x_curr.Σ, Ah, Qh, C_DxD, C_2DxD, diffusion)

Expand All @@ -104,8 +103,7 @@ function smooth!(

# x_curr.μ .+= G * (x_next.μ .- x_pred.μ) # less allocations:
x_pred.μ .-= x_next.μ
a = D ÷ _D
_matmul!(reshape_no_alloc(x_curr.μ, _D, a), G, reshape_no_alloc(x_pred.μ, _D, a), -1, 1)
_matmul!(x_curr.μ, G, x_pred.μ, -1, 1)

# Joseph-Form:
R = C_3DxD
Expand Down
3 changes: 1 addition & 2 deletions src/filtering/update.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,7 @@ function update!(
rdiv!(K, length(_S) == 1 ? _S[1] : cholesky!(_S))

# x_out.μ .= m_p .+ K * (0 .- z)
_matmul!(x_out.μ, K, z)
x_out.μ .= m_p .- x_out.μ
x_out.μ .= m_p .- _matmul!(x_out.μ, K, z)

# M_cache .= I(D) .- mul!(M_cache, K, H)
_matmul!(M_cache, K, H, -1.0, 0.0)
Expand Down
9 changes: 3 additions & 6 deletions src/initialization/classicsolverinit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,10 @@ function initial_update!(integ, cache, ::ClassicSolverInit)
# Initialize on u0; taking special care for DynamicalODEProblems
is_secondorder = integ.f isa DynamicalODEFunction
_u = is_secondorder ? view(u.x[2], :) : view(u, :)
E0 = x.Σ.R isa IsoKroneckerProduct ? Proj(0) : Matrix(Proj(0))
init_condition_on!(x, E0, _u, cache)
init_condition_on!(x, Proj(0), _u, cache)
is_secondorder ? f.f1(du, u.x[1], u.x[2], p, t) : f(du, u, p, t)
integ.stats.nf += 1
E1 = x.Σ.R isa IsoKroneckerProduct ? Proj(1) : Matrix(Proj(1))
init_condition_on!(x, E1, view(du, :), cache)
init_condition_on!(x, Proj(1), view(du, :), cache)

if q < 2
return
Expand All @@ -41,8 +39,7 @@ function initial_update!(integ, cache, ::ClassicSolverInit)
ForwardDiff.jacobian!(ddu, (du, u) -> _f(du, u, p, t), du, u)
end
ddfddu = ddu * view(du, :) + view(dfdt, :)
E2 = x.Σ.R isa IsoKroneckerProduct ? Proj(2) : Matrix(Proj(2))
init_condition_on!(x, E2, ddfddu, cache)
init_condition_on!(x, Proj(2), ddfddu, cache)
if q < 3
return
end
Expand Down
36 changes: 27 additions & 9 deletions src/kronecker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,27 @@ function check_same_size(A::IKP, B::IKP)
)
end
end
function check_matmul_sizes(A::IKP, B::IKP)
# For A * B
Ad, Bd = A.ldim, B.ldim
An, Am, Bn, Bm = size(A)..., size(B)...
if !(A.ldim == B.ldim) || !(Am == Bnb)
throw(
DimensionMismatch("Matrix multiplication not compatible: A has size ($Ad$An,$Ad$Am), B has size ($Bd$Bn,$Bd$Bm)"),
)
end
end
function check_matmul_sizes(C::IKP, A::IKP, B::IKP)
# For C = A * B
Ad, Bd, Cd = A.ldim, B.ldim, C.ldim
An, Am, Bn, Bm, Cn, Cm = size(A)..., size(B)..., size(C)...
if !(A.ldim == B.ldim == C.ldim) || !(Am == Bn && An == Cn && Bm == Cm)
throw(
DimensionMismatch("Matrix multiplication not compatible: A has size ($Ad$An,$Ad$Am), B has size ($Bd$Bn,$Bd$Bm), C has size ($Cd$Cn,$Cd$Cm)"),
)
end
end

Base.:+(A::IKP, B::IKP) = begin
check_same_size(A, B)
return IsoKroneckerProduct(A.ldim, A.B + B.B)
Expand All @@ -42,21 +63,21 @@ Base.:/(A::IKP, B::IKP) = begin
end
Base.:\(A::IKP, B::IKP) = begin
@assert A.ldim == B.ldim
return IsoKroneckerProduct(A.ldim, A.B / B.B)
return IsoKroneckerProduct(A.ldim, A.B \ B.B)
end

_matmul!(A::IKP, B::IKP, C::IKP) = begin
@assert A.ldim == B.ldim == C.ldim
check_matmul_sizes(A, B, C)
_matmul!(A.B, B.B, C.B)
return A
end
_matmul!(A::IKP{T}, B::IKP{T}, C::IKP{T}) where {T<:LinearAlgebra.BlasFloat} = begin
@assert A.ldim == B.ldim == C.ldim
check_matmul_sizes(A, B, C)
_matmul!(A.B, B.B, C.B)
return A
end
_matmul!(A::IKP, B::IKP, C::IKP, alpha::Number, beta::Number) = begin
@assert A.ldim == B.ldim == C.ldim
check_matmul_sizes(A, B, C)
_matmul!(A.B, B.B, C.B)
return A
end
Expand All @@ -67,12 +88,12 @@ _matmul!(
alpha::Number,
beta::Number,
) where {T<:LinearAlgebra.BlasFloat} = begin
@assert A.ldim == B.ldim == C.ldim
check_matmul_sizes(A, B, C)
_matmul!(A.B, B.B, C.B, alpha, beta)
return A
end
copy!(A::IKP, B::IKP) = begin
@assert A.ldim == B.ldim
check_same_size(A, B)
copy!(A.B, B.B)
return A
end
Expand All @@ -93,7 +114,6 @@ Found here: https://discourse.julialang.org/t/convert-array-into-matrix-in-place
"""
reshape_no_alloc(a, dims::Tuple) =
invoke(Base._reshape, Tuple{AbstractArray,typeof(dims)}, a, dims)
# reshape_no_alloc(a::AbstractArray, dims::Tuple) = reshape(a, dims)
reshape_no_alloc(a, dims...) = reshape_no_alloc(a, Tuple(dims))
reshape_no_alloc(a::Missing, dims::Tuple) = missing

Expand All @@ -103,8 +123,6 @@ function mul_vectrick!(x::AbstractVecOrMat, A::IsoKroneckerProduct, v::AbstractV

V = reshape_no_alloc(v, (d, length(v) ÷ d))
X = reshape_no_alloc(x, (c, length(x) ÷ c))
# @info "mul_vectrick!" typeof(x) typeof(A) typeof(v)
# @info "mul_vectrick!" typeof(X) typeof(N) typeof(V)
_matmul!(X, N, V)
return x
end
Expand Down
2 changes: 1 addition & 1 deletion src/perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ function estimate_errors!(cache::AbstractODEFilterCache)

# faster:
error_estimate = view(cache.tmp, 1:d)
if R isa Kronecker.AbstractKroneckerProduct
if R isa IsoKroneckerProduct
error_estimate .= sum(abs2, R.B)
else
sum!(abs2, error_estimate', view(R, :, 1:d))
Expand Down
3 changes: 1 addition & 2 deletions src/priors/iwp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,7 @@ function make_transition_matrices!(cache, prior::IWP, dt)
make_preconditioners!(cache, dt)
# A, Q = preconditioned_discretize(p) # not necessary since it's dt-independent
# Ah = PI * A * P
# @.. Ah.B = PI.B.diag * A.B * P.B.diag'
_matmul!(Ah, PI, _matmul!(Ah, A, P))
# Qh = PI * Q * PI'
fast_X_A_Xt!(Qh, Q, PI)
# @.. Qh.R.B = Q.R.B * PI.B.diag'
end
8 changes: 3 additions & 5 deletions src/solution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,17 +85,15 @@ function DiffEqBase.build_solution(
true,
Val(isinplace(prob)),
)
q = cache.q

T = eltype(eltype(u))
N = length((size(prob.u0)..., length(u)))

d = length(prob.u0)
uElType = eltype(prob.u0)
D = d

FAC = get_covariance_factorization(alg)
d, q = cache.d, cache.q
D = d * (q+1)

FAC = cache.covariance_factorization
pu_cov = PSDMatrix(factorized_zeros(FAC, uElType, D, d; d, q))
x_cov = PSDMatrix(factorized_zeros(FAC, uElType, D, D; d, q))
pu = StructArray{Gaussian{Vector{uElType},typeof(pu_cov)}}(undef, 0)
Expand Down

0 comments on commit 34b2424

Please sign in to comment.