Skip to content

Commit

Permalink
ManifoldUpdate callback (#79)
Browse files Browse the repository at this point in the history
* Add a simple non-iip `update` in Joseph form

* Implement first simple version of the ManifoldUpdate callback

* Format the files with JuliaFormatter.jl

* Make the ManifoldUpdate callback an IEKF update
  • Loading branch information
nathanaelbosch authored Oct 3, 2021
1 parent 4508747 commit cd0fea4
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 1 deletion.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ProbNumDiffEq"
uuid = "bf3e78b0-7d74-48a5-b855-9609533b56a5"
authors = ["Nathanael Bosch"]
version = "0.3.0"
version = "0.3.1"

[deps]
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
Expand Down
2 changes: 2 additions & 0 deletions src/ProbNumDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ include("ieks.jl")
export IEKS, solve_ieks

include("devtools.jl")
include("callbacks.jl")
export ManifoldUpdate

# Do as they do here:
# https://github.com/SciML/OrdinaryDiffEq.jl/blob/v5.61.1/src/OrdinaryDiffEq.jl#L175-L193
Expand Down
31 changes: 31 additions & 0 deletions src/callbacks.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
function ManifoldUpdate(residualf; maxiters=100, ϵ₁=1e-25, ϵ₂=1e-15)
condition(u, t, integ) = true

function affect!(integ)
@unpack u = integ
@unpack x, SolProj = integ.cache

f(m) = residualf(SolProj * m)

m, C = x
m_i = copy(m)

local m_i_new, C_i_new
for i in 1:maxiters
z = f(m_i)
J = ForwardDiff.jacobian(f, m_i)
S = X_A_Xt(C, J)

m_i_new, C_i_new = update(x, Gaussian(z .+ (J * (m - m_i)), S), J)

if norm(m_i_new .- m_i) < ϵ₁ && norm(z) < ϵ₂
break
end
m_i = m_i_new
end
copy!(x, Gaussian(m_i_new, C_i_new))

return nothing
end
return DiscreteCallback(condition, affect!)
end
11 changes: 11 additions & 0 deletions src/filtering/update.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,17 @@ function update(x::Gaussian, measurement::Gaussian, H::AbstractMatrix)

return Gaussian(m_new, C_new)
end
function update(x::SRGaussian, measurement::SRGaussian, H::AbstractMatrix)
"""In Joseph form"""
m, C = x
z, S = measurement

K = C * H' * inv(S)
m_new = m - K * z
C_new = X_A_Xt(C, (I - K * H))

return Gaussian(m_new, C_new)
end

"""
update!(x_out, x_pred, measurement, H, R=0)
Expand Down
16 changes: 16 additions & 0 deletions test/specific_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,22 @@ end
@test solve(prob, EK0(order=3), callback=Callback()) isa ProbNumDiffEq.ProbODESolution
end

@testset "ManifoldUpdate callback test" begin
# Again: Harmonic Oscillator with condition on E=2
u0 = ones(2)
function harmonic_oscillator(du, u, p, t)
du[1] = u[2]
return du[2] = -u[1]
end
prob = ODEProblem(harmonic_oscillator, u0, (0.0, 100.0))

E(u) = [dot(u, u) - 2]

@test solve(prob, EK0(order=3)) isa ProbNumDiffEq.ProbODESolution
@test solve(prob, EK0(order=3), callback=ManifoldUpdate(E)) isa
ProbNumDiffEq.ProbODESolution
end

@testset "Problem definition with ParameterizedFunctions.jl" begin
f = @ode_def LotkaVolterra begin
dx = a * x - b * x * y
Expand Down

2 comments on commit cd0fea4

@nathanaelbosch
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/45992

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.3.1 -m "<description of version>" cd0fea451e8566c6d644a9440a26319e13c99d91
git push origin v0.3.1

Please sign in to comment.