Skip to content

Commit

Permalink
simplified continuous adjoints
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Jan 17, 2018
1 parent c6ecce2 commit 32a1f10
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 24 deletions.
1 change: 1 addition & 0 deletions REQUIRE
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ Compat 0.17.0
ForwardDiff
DiffEqDiffTools
DiffEqCallbacks
QuadGK
5 changes: 3 additions & 2 deletions src/DiffEqSensitivity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ __precompile__()

module DiffEqSensitivity

using DiffEqBase, Compat, ForwardDiff, DiffEqDiffTools, DiffEqCallbacks
using DiffEqBase, Compat, ForwardDiff, DiffEqDiffTools, DiffEqCallbacks, QuadGK

abstract type SensitivityFunction end

Expand All @@ -13,5 +13,6 @@ include("adjoint_sensitivity.jl")
export extract_local_sensitivities

export ODELocalSensitvityFunction, ODELocalSensitivityProblem, SensitivityFunction,
ODEAdjointSensitivityProblem, ODEAdjointProblem, AdjointSensitivityIntegrand
ODEAdjointSensitivityProblem, ODEAdjointProblem, AdjointSensitivityIntegrand,
adjoint_sensitivities
end # module
16 changes: 15 additions & 1 deletion src/adjoint_sensitivity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ function ODEAdjointProblem(sol,g,t=nothing,dg=nothing,
end

u0 = zeros(sol.prob.u0)'
y = sol(tspan[1]) # TODO: Has to start at interpolation value!
y = copy(sol(tspan[1])) # TODO: Has to start at interpolation value!
λ = similar(u0)
sense = ODEAdjointSensitvityFunction(f,uf,pg,u0,jac_config,pg_config,p,
λ,alg,discrete,
Expand Down Expand Up @@ -177,3 +177,17 @@ function (S::AdjointSensitivityIntegrand)(t)
out = similar(S.p)
S(out,t)
end


function adjoint_sensitivities(sol,alg,g,t=nothing,dg=nothing;
abstol=1e-6,reltol=1e-3,
iabstol = abstol, ireltol=reltol,
kwargs...)

adj_prob = ODEAdjointProblem(sol,g,t,dg)
adj_sol = solve(adj_prob,alg,abstol=abstol,reltol=reltol)
integrand = AdjointSensitivityIntegrand(sol,adj_sol)
res,err = quadgk(integrand,sol.prob.tspan[1],sol.prob.tspan[2],
abstol=iabstol,reltol=ireltol)
res
end
37 changes: 16 additions & 21 deletions test/adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,22 @@ sol = solve(prob,Vern9(),abstol=1e-14,reltol=1e-14)

t = 0.0:0.5:10.0 # TODO: Add end point handling for callback
# g(t,u,i) = (1-u)^2/2, L2 away from 1
dg(out,u,i) = (out.=1.-u)
dg(out,u,i) = (out.=1.0.-u)

easy_res = adjoint_sensitivities(sol,Vern9(),dg,t,abstol=1e-14,
reltol=1e-14,iabstol=1e-14,ireltol=1e-12)

adj_prob = ODEAdjointProblem(sol,dg,t)
adj_sol = solve(adj_prob,Vern9(),abstol=1e-14,reltol=1e-14)
integrand = AdjointSensitivityIntegrand(sol,adj_sol)
res,err = quadgk(integrand,0.0,10.0,abstol=1e-14,reltol=1e-10)
res,err = quadgk(integrand,0.0,10.0,abstol=1e-14,reltol=1e-12)

using Plots
gr()
plot(adj_sol,tspan=(0.0,10.0))
@test norm(res - easy_res) < 1e-14

t_short = 0.5:0.5:10.0
function G(p)
tmp_prob = problem_new_parameters(prob,p)
sol = solve(tmp_prob,Vern9(),abstol=1e-14,reltol=1e-14,saveat=t_short,save_start=false)
sol = solve(tmp_prob,Vern9(),abstol=1e-14,reltol=1e-14,saveat=t)
A = convert(Array,sol)
sum(((1-A).^2)./2)
end
Expand All @@ -42,37 +43,31 @@ res3 = Calculus.gradient(G,[1.5,1.0,3.0])
# Do a continuous adjoint problem

# Energy calculation
g(t,u,p) = (out.=(sum(u).^2) ./ 2)
g(t,u,p) = (sum(u).^2) ./ 2
# Gradient of (u1 + u2)^2 / 2
function dg(out,t,u,p)
out[1]= u[1] + u[2]
out[2]= u[1] + u[2]
end

adj_prob = ODEAdjointProblem(sol,g,nothing,dg)
adj_sol = solve(adj_prob,Vern9(),abstol=1e-14,reltol=1e-14)
adj_sol = solve(adj_prob,Vern9(),abstol=1e-14,reltol=1e-10)
integrand = AdjointSensitivityIntegrand(sol,adj_sol)
res,err = quadgk(integrand,0.0,10.0,abstol=1e-14,reltol=1e-10)

easy_res = adjoint_sensitivities(sol,Vern9(),g,nothing,dg,abstol=1e-14,
reltol=1e-14,iabstol=1e-14,ireltol=1e-12)

function G(p)
tmp_prob = problem_new_parameters(prob,p)
sol = solve(tmp_prob,Vern9(),abstol=1e-14,reltol=1e-14,saveat=0.0:0.00001:10.0)
# Trapezoidal rule since quadgk can't autodiff
A = convert(Array,sol)
sum((sum(A,1).^2)./2)* 0.00001
end
G([1.5,1.0,3.0])
res2 = ForwardDiff.gradient(G,[1.5,1.0,3.0])
@test norm(easy_res - res) < 1e-8

function G_calc(p)
function G(p)
tmp_prob = problem_new_parameters(prob,p)
sol = solve(tmp_prob,Vern9(),abstol=1e-14,reltol=1e-14)
res,err = quadgk((t)-> (sum(sol(t)).^2)./2,0.0,10.0,abstol=1e-14,reltol=1e-10)
res
end
res2 = ForwardDiff.gradient(G,[1.5,1.0,3.0])
res3 = Calculus.gradient(G,[1.5,1.0,3.0])

res3 = Calculus.gradient(G_calc,[1.5,1.0,3.0])

@test norm(res' - res2) < 1e-4
@test norm(res' - res2) < 1e-8
@test norm(res' - res3) < 1e-6
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ using DiffEqSensitivity
using Base.Test

@testset "Local Sensitivity" begin include("local.jl") end
@testset "Adjoint Sensitivity" begin include("adjoint.jl") end

0 comments on commit 32a1f10

Please sign in to comment.