Skip to content

Commit

Permalink
Merge pull request #1033 from SciML/ap/fix_reversediff
Browse files Browse the repository at this point in the history
Fix Test Failures
  • Loading branch information
avik-pal authored Mar 14, 2024
2 parents 4a86cef + 3471ad7 commit 6e48a61
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 27 deletions.
5 changes: 5 additions & 0 deletions src/callback_tracking.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,11 @@ struct FakeIntegrator{uType, P, tType, tprevType}
tprev::tprevType
end

function Base.getproperty(fi::FakeIntegrator, s::Symbol)
s === :tdir && return sign(fi.t - fi.tprev)
return getfield(fi, s)
end

struct CallbackSensitivityFunction{fType, Alg <: AbstractOverloadingSensitivityAlgorithm,
C <: AdjointDiffCache, pType} <: SensitivityFunction
f::fType
Expand Down
2 changes: 1 addition & 1 deletion src/gauss_adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,7 @@ function (S::GaussIntegrand)(out, t, λ)
y = sol(t)
end
vec_pjac!(out, λ, y, t, S)
out .*= -1
out = recursive_neg!(out)
if S.dgdp !== nothing
S.dgdp(dgdp_cache, y, p, t)
out .+= dgdp_cache
Expand Down
12 changes: 7 additions & 5 deletions test/adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1063,15 +1063,17 @@ _, res = adjoint_sensitivities(sol_singular_mm, alg, t = ts,
reference_sol = ForwardDiff.gradient(
p -> G(p, prob_singular_mm, ts,
sol -> sum(last, sol.u)), vec(p))

for salg in [
QuadratureAdjoint(),
InterpolatingAdjoint(),
BacksolveAdjoint(),
GaussAdjoint()
]
QuadratureAdjoint(),
InterpolatingAdjoint(),
BacksolveAdjoint(),
GaussAdjoint()
]
_, res = adjoint_sensitivities(sol_singular_mm, alg, t = ts,
dgdu_discrete = dg_singular, abstol = 1e-14,
reltol = 1e-14, sensealg = salg,
maxiters = Int(1e6))

@test res'reference_sol rtol=1e-7
end
38 changes: 18 additions & 20 deletions test/parameter_handling.jl
Original file line number Diff line number Diff line change
@@ -1,31 +1,29 @@
using SciMLSensitivity, Lux, Random, Zygote, NonlinearSolve, OrdinaryDiffEq, Test

@static if VERSION v"1.9"
@info "Testing Nonlinear Solve Adjoint with Nested Parameters"
@info "Testing Nonlinear Solve Adjoint with Nested Parameters"

const model_nls = Chain(Dense(2 => 2, tanh), Dense(2 => 2))
ps, st = Lux.setup(Random.default_rng(), model_nls)
const st_nls = st
const model_nls = Chain(Dense(2 => 2, tanh), Dense(2 => 2))
ps, st = Lux.setup(Random.default_rng(), model_nls)
const st_nls = st

x = ones(Float32, 2, 3)
x = ones(Float32, 2, 3)

nlprob(u, p) = first(model_nls(u, p, st_nls)) .- u
nlprob(u, p) = first(model_nls(u, p, st_nls)) .- u

prob = NonlinearProblem(nlprob, zeros(2, 3), ps)
prob = NonlinearProblem(nlprob, zeros(2, 3), ps)

@test_nowarn solve(prob, NewtonRaphson())
@test_nowarn solve(prob, NewtonRaphson())

gs = only(Zygote.gradient(ps) do ps
prob = NonlinearProblem(nlprob, zero.(x), ps)
sol = solve(prob, NewtonRaphson())
return sum(sol.u)
end)
gs = only(Zygote.gradient(ps) do ps
prob = NonlinearProblem(nlprob, zero.(x), ps)
sol = solve(prob, NewtonRaphson())
return sum(sol.u)
end)

@test gs.layer_1.weight !== nothing
@test gs.layer_1.bias !== nothing
@test gs.layer_2.weight !== nothing
@test gs.layer_2.bias !== nothing
end
@test gs.layer_1.weight !== nothing
@test gs.layer_1.bias !== nothing
@test gs.layer_2.weight !== nothing
@test gs.layer_2.bias !== nothing

@info "Testing Gauss Adjoint with Nested Parameters"

Expand All @@ -43,7 +41,7 @@ prob = ODEProblem(odeprob, ones(2, 3), (0.0f0, 1.0f0), ps)

gs = only(Zygote.gradient(ps) do ps
prob = ODEProblem(odeprob, ones(2, 3), (0.0f0, 1.0f0), ps)
sol = solve(prob, Tsit5(); sensealg = GaussAdjoint(; autodiff = ZygoteVJP()))
sol = solve(prob, Tsit5(); sensealg = GaussAdjoint(; autojacvec = ZygoteVJP()))
return sum(last(sol.u))
end)

Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ end

if GROUP == "DiffEq"
@testset "DiffEq" begin
activate_gpu_env()
activate_diffeq_env()
@time @safetestset "Default DiffEq Alg" include("diffeq/default_alg_diff.jl")
end
end
Expand Down

0 comments on commit 6e48a61

Please sign in to comment.