Skip to content

Commit

Permalink
Fix parameter handling in gauss adjoint
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Mar 12, 2024
1 parent b4b0fed commit 9dba71c
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 26 deletions.
2 changes: 1 addition & 1 deletion src/gauss_adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,7 @@ function (S::GaussIntegrand)(out, t, λ)
y = sol(t)
end
vec_pjac!(out, λ, y, t, S)
out .*= -1
out = recursive_neg!(out)
if S.dgdp !== nothing
S.dgdp(dgdp_cache, y, p, t)
out .+= dgdp_cache
Expand Down
12 changes: 7 additions & 5 deletions test/adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1063,15 +1063,17 @@ _, res = adjoint_sensitivities(sol_singular_mm, alg, t = ts,
reference_sol = ForwardDiff.gradient(
p -> G(p, prob_singular_mm, ts,
sol -> sum(last, sol.u)), vec(p))

for salg in [
QuadratureAdjoint(),
InterpolatingAdjoint(),
BacksolveAdjoint(),
GaussAdjoint()
]
QuadratureAdjoint(),
InterpolatingAdjoint(),
BacksolveAdjoint(),
GaussAdjoint()
]
_, res = adjoint_sensitivities(sol_singular_mm, alg, t = ts,
dgdu_discrete = dg_singular, abstol = 1e-14,
reltol = 1e-14, sensealg = salg,
maxiters = Int(1e6))

@test res'reference_sol rtol=1e-7
end
38 changes: 18 additions & 20 deletions test/parameter_handling.jl
Original file line number Diff line number Diff line change
@@ -1,31 +1,29 @@
using SciMLSensitivity, Lux, Random, Zygote, NonlinearSolve, OrdinaryDiffEq, Test

@static if VERSION v"1.9"
@info "Testing Nonlinear Solve Adjoint with Nested Parameters"
@info "Testing Nonlinear Solve Adjoint with Nested Parameters"

const model_nls = Chain(Dense(2 => 2, tanh), Dense(2 => 2))
ps, st = Lux.setup(Random.default_rng(), model_nls)
const st_nls = st
const model_nls = Chain(Dense(2 => 2, tanh), Dense(2 => 2))
ps, st = Lux.setup(Random.default_rng(), model_nls)
const st_nls = st

x = ones(Float32, 2, 3)
x = ones(Float32, 2, 3)

nlprob(u, p) = first(model_nls(u, p, st_nls)) .- u
nlprob(u, p) = first(model_nls(u, p, st_nls)) .- u

prob = NonlinearProblem(nlprob, zeros(2, 3), ps)
prob = NonlinearProblem(nlprob, zeros(2, 3), ps)

@test_nowarn solve(prob, NewtonRaphson())
@test_nowarn solve(prob, NewtonRaphson())

gs = only(Zygote.gradient(ps) do ps
prob = NonlinearProblem(nlprob, zero.(x), ps)
sol = solve(prob, NewtonRaphson())
return sum(sol.u)
end)
gs = only(Zygote.gradient(ps) do ps
prob = NonlinearProblem(nlprob, zero.(x), ps)
sol = solve(prob, NewtonRaphson())
return sum(sol.u)
end)

@test gs.layer_1.weight !== nothing
@test gs.layer_1.bias !== nothing
@test gs.layer_2.weight !== nothing
@test gs.layer_2.bias !== nothing
end
@test gs.layer_1.weight !== nothing
@test gs.layer_1.bias !== nothing
@test gs.layer_2.weight !== nothing
@test gs.layer_2.bias !== nothing

@info "Testing Gauss Adjoint with Nested Parameters"

Expand All @@ -43,7 +41,7 @@ prob = ODEProblem(odeprob, ones(2, 3), (0.0f0, 1.0f0), ps)

gs = only(Zygote.gradient(ps) do ps
prob = ODEProblem(odeprob, ones(2, 3), (0.0f0, 1.0f0), ps)
sol = solve(prob, Tsit5(); sensealg = GaussAdjoint(; autodiff = ZygoteVJP()))
sol = solve(prob, Tsit5(); sensealg = GaussAdjoint(; autojacvec = ZygoteVJP()))
return sum(last(sol.u))
end)

Expand Down

0 comments on commit 9dba71c

Please sign in to comment.