From 32a1f1058e21db8b0935ce72cf5a5ffc42a84636 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Wed, 17 Jan 2018 15:48:30 -0800 Subject: [PATCH] simplified continuous adjoints --- REQUIRE | 1 + src/DiffEqSensitivity.jl | 5 +++-- src/adjoint_sensitivity.jl | 16 +++++++++++++++- test/adjoint.jl | 37 ++++++++++++++++--------------------- test/runtests.jl | 1 + 5 files changed, 36 insertions(+), 24 deletions(-) diff --git a/REQUIRE b/REQUIRE index efdcecf68..6acad68f8 100644 --- a/REQUIRE +++ b/REQUIRE @@ -4,3 +4,4 @@ Compat 0.17.0 ForwardDiff DiffEqDiffTools DiffEqCallbacks +QuadGK diff --git a/src/DiffEqSensitivity.jl b/src/DiffEqSensitivity.jl index 5ebf33e6c..d039274f5 100644 --- a/src/DiffEqSensitivity.jl +++ b/src/DiffEqSensitivity.jl @@ -2,7 +2,7 @@ __precompile__() module DiffEqSensitivity -using DiffEqBase, Compat, ForwardDiff, DiffEqDiffTools, DiffEqCallbacks +using DiffEqBase, Compat, ForwardDiff, DiffEqDiffTools, DiffEqCallbacks, QuadGK abstract type SensitivityFunction end @@ -13,5 +13,6 @@ include("adjoint_sensitivity.jl") export extract_local_sensitivities export ODELocalSensitvityFunction, ODELocalSensitivityProblem, SensitivityFunction, - ODEAdjointSensitivityProblem, ODEAdjointProblem, AdjointSensitivityIntegrand + ODEAdjointSensitivityProblem, ODEAdjointProblem, AdjointSensitivityIntegrand, + adjoint_sensitivities end # module diff --git a/src/adjoint_sensitivity.jl b/src/adjoint_sensitivity.jl index cfbd3c8b5..55cf1ff02 100644 --- a/src/adjoint_sensitivity.jl +++ b/src/adjoint_sensitivity.jl @@ -94,7 +94,7 @@ function ODEAdjointProblem(sol,g,t=nothing,dg=nothing, end u0 = zeros(sol.prob.u0)' - y = sol(tspan[1]) # TODO: Has to start at interpolation value! + y = copy(sol(tspan[1])) # TODO: Has to start at interpolation value! λ = similar(u0) sense = ODEAdjointSensitvityFunction(f,uf,pg,u0,jac_config,pg_config,p, λ,alg,discrete, @@ -177,3 +177,17 @@ function (S::AdjointSensitivityIntegrand)(t) out = similar(S.p) S(out,t) end + + +function adjoint_sensitivities(sol,alg,g,t=nothing,dg=nothing; + abstol=1e-6,reltol=1e-3, + iabstol = abstol, ireltol=reltol, + kwargs...) + + adj_prob = ODEAdjointProblem(sol,g,t,dg) + adj_sol = solve(adj_prob,alg,abstol=abstol,reltol=reltol) + integrand = AdjointSensitivityIntegrand(sol,adj_sol) + res,err = quadgk(integrand,sol.prob.tspan[1],sol.prob.tspan[2], + abstol=iabstol,reltol=ireltol) + res +end diff --git a/test/adjoint.jl b/test/adjoint.jl index fc08ffaeb..48bc9e739 100644 --- a/test/adjoint.jl +++ b/test/adjoint.jl @@ -14,21 +14,22 @@ sol = solve(prob,Vern9(),abstol=1e-14,reltol=1e-14) t = 0.0:0.5:10.0 # TODO: Add end point handling for callback # g(t,u,i) = (1-u)^2/2, L2 away from 1 -dg(out,u,i) = (out.=1.-u) +dg(out,u,i) = (out.=1.0.-u) + +easy_res = adjoint_sensitivities(sol,Vern9(),dg,t,abstol=1e-14, + reltol=1e-14,iabstol=1e-14,ireltol=1e-12) adj_prob = ODEAdjointProblem(sol,dg,t) adj_sol = solve(adj_prob,Vern9(),abstol=1e-14,reltol=1e-14) integrand = AdjointSensitivityIntegrand(sol,adj_sol) -res,err = quadgk(integrand,0.0,10.0,abstol=1e-14,reltol=1e-10) +res,err = quadgk(integrand,0.0,10.0,abstol=1e-14,reltol=1e-12) -using Plots -gr() -plot(adj_sol,tspan=(0.0,10.0)) +@test norm(res - easy_res) < 1e-14 t_short = 0.5:0.5:10.0 function G(p) tmp_prob = problem_new_parameters(prob,p) - sol = solve(tmp_prob,Vern9(),abstol=1e-14,reltol=1e-14,saveat=t_short,save_start=false) + sol = solve(tmp_prob,Vern9(),abstol=1e-14,reltol=1e-14,saveat=t) A = convert(Array,sol) sum(((1-A).^2)./2) end @@ -42,7 +43,7 @@ res3 = Calculus.gradient(G,[1.5,1.0,3.0]) # Do a continuous adjoint problem # Energy calculation -g(t,u,p) = (out.=(sum(u).^2) ./ 2) +g(t,u,p) = (sum(u).^2) ./ 2 # Gradient of (u1 + u2)^2 / 2 function dg(out,t,u,p) out[1]= u[1] + u[2] @@ -50,29 +51,23 @@ function dg(out,t,u,p) end adj_prob = ODEAdjointProblem(sol,g,nothing,dg) -adj_sol = solve(adj_prob,Vern9(),abstol=1e-14,reltol=1e-14) +adj_sol = solve(adj_prob,Vern9(),abstol=1e-14,reltol=1e-10) integrand = AdjointSensitivityIntegrand(sol,adj_sol) res,err = quadgk(integrand,0.0,10.0,abstol=1e-14,reltol=1e-10) +easy_res = adjoint_sensitivities(sol,Vern9(),g,nothing,dg,abstol=1e-14, + reltol=1e-14,iabstol=1e-14,ireltol=1e-12) -function G(p) - tmp_prob = problem_new_parameters(prob,p) - sol = solve(tmp_prob,Vern9(),abstol=1e-14,reltol=1e-14,saveat=0.0:0.00001:10.0) - # Trapezoidal rule since quadgk can't autodiff - A = convert(Array,sol) - sum((sum(A,1).^2)./2)* 0.00001 - end -G([1.5,1.0,3.0]) -res2 = ForwardDiff.gradient(G,[1.5,1.0,3.0]) +@test norm(easy_res - res) < 1e-8 -function G_calc(p) +function G(p) tmp_prob = problem_new_parameters(prob,p) sol = solve(tmp_prob,Vern9(),abstol=1e-14,reltol=1e-14) res,err = quadgk((t)-> (sum(sol(t)).^2)./2,0.0,10.0,abstol=1e-14,reltol=1e-10) res end +res2 = ForwardDiff.gradient(G,[1.5,1.0,3.0]) +res3 = Calculus.gradient(G,[1.5,1.0,3.0]) -res3 = Calculus.gradient(G_calc,[1.5,1.0,3.0]) - -@test norm(res' - res2) < 1e-4 +@test norm(res' - res2) < 1e-8 @test norm(res' - res3) < 1e-6 diff --git a/test/runtests.jl b/test/runtests.jl index 960619eeb..9c882bf24 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,3 +2,4 @@ using DiffEqSensitivity using Base.Test @testset "Local Sensitivity" begin include("local.jl") end +@testset "Adjoint Sensitivity" begin include("adjoint.jl") end