Skip to content

Commit

Permalink
Merge pull request #55 from JuliaDiffEq/adj_tests
Browse files Browse the repository at this point in the history
speed up tests by using Tsit5
  • Loading branch information
ChrisRackauckas authored Apr 9, 2019
2 parents 59ee737 + 9bb7fc1 commit 8e763fa
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 25 deletions.
53 changes: 32 additions & 21 deletions test/adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@ end a b c

p = [1.5,1.0,3.0]
prob = ODEProblem(f,[1.0;1.0],(0.0,10.0),p)
sol = solve(prob,Vern9(),abstol=1e-14,reltol=1e-14)
sol = solve(prob,Tsit5(),abstol=1e-14,reltol=1e-14)
probb = ODEProblem(fb,[1.0;1.0],(0.0,10.0),p)
solb = solve(probb,Vern9(),abstol=1e-14,reltol=1e-14)
solb = solve(probb,Tsit5(),abstol=1e-14,reltol=1e-14)
sol_end = solve(probb,Tsit5(),abstol=1e-14,reltol=1e-14,
save_everystep=false,save_start=false)

# Do a discrete adjoint problem
println("Calculate discrete adjoint sensitivities")
Expand All @@ -26,27 +28,27 @@ function dg(out,u,p,t,i)
(out.=2.0.-u)
end

easy_res = adjoint_sensitivities(sol,Vern9(),dg,t,abstol=1e-14,
easy_res = adjoint_sensitivities(sol,Tsit5(),dg,t,abstol=1e-14,
reltol=1e-14,iabstol=1e-14,ireltol=1e-12)
easy_res2 = adjoint_sensitivities(solb,Vern9(),dg,t,abstol=1e-14,
easy_res2 = adjoint_sensitivities(solb,Tsit5(),dg,t,abstol=1e-14,
reltol=1e-14,iabstol=1e-14,ireltol=1e-12,sensealg=SensitivityAlg(quad=true,backsolve=false))
easy_res3 = adjoint_sensitivities(solb,Vern9(),dg,t,abstol=1e-14,
easy_res3 = adjoint_sensitivities(solb,Tsit5(),dg,t,abstol=1e-14,
reltol=1e-14,iabstol=1e-14,ireltol=1e-12,sensealg=SensitivityAlg(quad=false,backsolve=false))
easy_res4 = adjoint_sensitivities(solb,Vern9(),dg,t,abstol=1e-14,
easy_res4 = adjoint_sensitivities(solb,Tsit5(),dg,t,abstol=1e-14,
reltol=1e-14,iabstol=1e-14,ireltol=1e-12,sensealg=SensitivityAlg(backsolve=true))
easy_res5 = adjoint_sensitivities(sol,Kvaerno5(nlsolve=NLAnderson(), smooth_est=false),dg,t,abstol=1e-14,
reltol=1e-14,iabstol=1e-14,ireltol=1e-12,sensealg=SensitivityAlg(backsolve=true))
easy_res6 = adjoint_sensitivities(solb,Vern9(),dg,t,abstol=1e-14,
easy_res6 = adjoint_sensitivities(solb,Tsit5(),dg,t,abstol=1e-14,
reltol=1e-14,iabstol=1e-14,ireltol=1e-12,
sensealg=SensitivityAlg(checkpointing=true,quad=true),
checkpoints=sol.t[1:5:end])
easy_res7 = adjoint_sensitivities(solb,Vern9(),dg,t,abstol=1e-14,
easy_res7 = adjoint_sensitivities(solb,Tsit5(),dg,t,abstol=1e-14,
reltol=1e-14,iabstol=1e-14,ireltol=1e-12,
sensealg=SensitivityAlg(checkpointing=true,quad=false),
checkpoints=sol.t[1:5:end])

adj_prob = ODEAdjointProblem(sol,dg,t)
adj_sol = solve(adj_prob,Vern9(),abstol=1e-14,reltol=1e-14)
adj_sol = solve(adj_prob,Tsit5(),abstol=1e-14,reltol=1e-14)
integrand = AdjointSensitivityIntegrand(sol,adj_sol)
res,err = quadgk(integrand,0.0,10.0,atol=1e-14,rtol=1e-12)

Expand All @@ -58,17 +60,26 @@ res,err = quadgk(integrand,0.0,10.0,atol=1e-14,rtol=1e-12)
@test isapprox(res, easy_res6, rtol = 1e-9)
@test isapprox(res, easy_res7, rtol = 1e-9)

easy_res8 = adjoint_sensitivities(solb,Vern9(),dg,t,abstol=1e-14,
println("Calculate adjoint sensitivities ")

easy_res8 = adjoint_sensitivities(solb,Tsit5(),dg,t,abstol=1e-14,
reltol=1e-14,iabstol=1e-14,ireltol=1e-12,
save_everystep=false,save_start=false,
sensealg=SensitivityAlg(backsolve=true))

@test isapprox(res, easy_res8, rtol = 1e-9)

end_only_res = adjoint_sensitivities(sol_end,Tsit5(),dg,t,abstol=1e-14,
reltol=1e-14,iabstol=1e-14,ireltol=1e-12,
save_everystep=false,save_start=false,
sensealg=SensitivityAlg(backsolve=true))

@test isapprox(res, end_only_res, rtol = 1e-9)

println("Calculate adjoint sensitivities from autodiff & numerical diff")
function G(p)
tmp_prob = remake(prob,u0=convert.(eltype(p),prob.u0),p=p)
sol = solve(tmp_prob,Vern9(),abstol=1e-14,reltol=1e-14,saveat=t)
sol = solve(tmp_prob,Tsit5(),abstol=1e-14,reltol=1e-14,saveat=t)
A = convert(Array,sol)
sum(((2 .- A).^2)./2)
end
Expand Down Expand Up @@ -130,7 +141,7 @@ println("Calculate adjoint sensitivities from autodiff & numerical diff")
function G(p)
tmp_prob = remake(prob,u0=eltype(p).(prob.u0),p=p,
tspan=eltype(p).(prob.tspan))
sol = solve(tmp_prob,Vern9(),abstol=1e-14,reltol=1e-14)
sol = solve(tmp_prob,Tsit5(),abstol=1e-14,reltol=1e-14)
res,err = quadgk((t)-> (sum(sol(t)).^2)./2,0.0,10.0,atol=1e-14,rtol=1e-10)
res
end
Expand All @@ -144,8 +155,8 @@ res3 = Calculus.gradient(G,[1.5,1.0,3.0])
f = (du, u, p, t) -> du .= 0
p = zeros(3); u = zeros(50)
prob = ODEProblem(f,u,(0.0,10.0),p)
sol = solve(prob,Vern9(),abstol=1e-14,reltol=1e-14)
@test_nowarn res = adjoint_sensitivities(sol,Vern9(),dg,t,abstol=1e-14,
sol = solve(prob,Tsit5(),abstol=1e-14,reltol=1e-14)
@test_nowarn res = adjoint_sensitivities(sol,Tsit5(),dg,t,abstol=1e-14,
reltol=1e-14,iabstol=1e-14,ireltol=1e-12)

println("Adjoints of u0")
Expand All @@ -154,29 +165,29 @@ function dg(out,u,p,t,i)
out .= 1 .- u
end

ū0,adj = adjoint_sensitivities_u0(sol,Vern9(),dg,t,abstol=1e-14,
ū0,adj = adjoint_sensitivities_u0(sol,Tsit5(),dg,t,abstol=1e-14,
reltol=1e-14,iabstol=1e-14,ireltol=1e-12)

adj2 = adjoint_sensitivities(sol,Vern9(),dg,t,abstol=1e-14,
adj2 = adjoint_sensitivities(sol,Tsit5(),dg,t,abstol=1e-14,
reltol=1e-14,iabstol=1e-14,ireltol=1e-12)

ū02,adj3 = adjoint_sensitivities_u0(sol,Vern9(),dg,t,abstol=1e-14,
ū02,adj3 = adjoint_sensitivities_u0(sol,Tsit5(),dg,t,abstol=1e-14,
sensealg=SensitivityAlg(backsolve=true),
reltol=1e-14,iabstol=1e-14,ireltol=1e-12)

ū03,adj4 = adjoint_sensitivities_u0(sol,Vern9(),dg,t,abstol=1e-14,
ū03,adj4 = adjoint_sensitivities_u0(sol,Tsit5(),dg,t,abstol=1e-14,
save_everystep=false, save_start=false,
sensealg=SensitivityAlg(backsolve=true),
reltol=1e-14,iabstol=1e-14,ireltol=1e-12)

ū03,adj4 = adjoint_sensitivities_u0(sol,Vern9(),dg,t,abstol=1e-14,
ū03,adj4 = adjoint_sensitivities_u0(sol,Tsit5(),dg,t,abstol=1e-14,
save_everystep=false, save_start=false,
sensealg=SensitivityAlg(backsolve=true),
reltol=1e-14,iabstol=1e-14,ireltol=1e-12)

res = ForwardDiff.gradient(prob.u0) do u0
tmp_prob = remake(prob,u0=u0)
sol = solve(tmp_prob,Vern9(),abstol=1e-14,reltol=1e-14,saveat=t)
sol = solve(tmp_prob,Tsit5(),abstol=1e-14,reltol=1e-14,saveat=t)
A = convert(Array,sol)
sum(((1 .- A).^2)./2)
end
Expand Down
8 changes: 4 additions & 4 deletions test/local.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ prob = ODELocalSensitivityProblem(f,[1.0;1.0],(0.0,10.0),p)
probInpl = ODELocalSensitivityProblem(LotkaVolt!,[1.0;1.0],(0.0,10.0),p)
probnoad = ODELocalSensitivityProblem(LotkaVolt!,[1.0;1.0],(0.0,10.0),p,
SensitivityAlg(autodiff=false))
sol = solve(prob,Vern9(),abstol=1e-14,reltol=1e-14)
sol = solve(prob,Tsit5(),abstol=1e-14,reltol=1e-14)
@test_broken solInpl = solve(probInpl,KenCarp4(),abstol=1e-14,reltol=1e-14)
@test_broken solInpl2 = solve(probInpl,Rodas4(autodiff=false),abstol=1e-14,reltol=1e-14)
solInpl = solve(probInpl,KenCarp4(autodiff=false),abstol=1e-14,reltol=1e-14)
Expand All @@ -41,7 +41,7 @@ dc = sol[sol.prob.f.numindvar*3+1:sol.prob.f.numindvar*4,:]
sense_res1 = [da[:,end] db[:,end] dc[:,end]]

prob = ODELocalSensitivityProblem(f.f,[1.0;1.0],(0.0,10.0),p,SensitivityAlg(autojacvec=true))
sol = solve(prob,Vern9(),abstol=1e-14,reltol=1e-14)
sol = solve(prob,Tsit5(),abstol=1e-14,reltol=1e-14)
x = sol[1:sol.prob.f.numindvar,:]

# Get the sensitivities
Expand All @@ -54,7 +54,7 @@ sense_res2 = [da[:,end] db[:,end] dc[:,end]]

function test_f(p)
prob = ODEProblem(f,eltype(p).([1.0,1.0]),(0.0,10.0),p)
solve(prob,Vern9(),abstol=1e-14,reltol=1e-14,save_everystep=false)[end]
solve(prob,Tsit5(),abstol=1e-14,reltol=1e-14,save_everystep=false)[end]
end

p = [1.5,1.0,3.0]
Expand All @@ -74,7 +74,7 @@ function f2(du,u,p,t)
end
p = [1.5,1.0,3.0]
prob = ODELocalSensitivityProblem(f2,[1.0;1.0],(0.0,10.0),p)
sol = solve(prob,Vern9(),abstol=1e-14,reltol=1e-14)
sol = solve(prob,Tsit5(),abstol=1e-14,reltol=1e-14)
res = sol[1:sol.prob.f.numindvar,:]

# Get the sensitivities
Expand Down

0 comments on commit 8e763fa

Please sign in to comment.