Skip to content

Commit

Permalink
Merge pull request #57 from YingboMa/myb/compile
Browse files Browse the repository at this point in the history
Improve compile time
  • Loading branch information
ChrisRackauckas authored Apr 11, 2019
2 parents 8e763fa + 018599c commit 21540e3
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 26 deletions.
35 changes: 19 additions & 16 deletions src/adjoint_sensitivity.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using Flux.Tracker: gradient

struct ODEAdjointSensitivityFunction{dgType,rateType,uType,F,J,PJ,UF,PF,G,JC,GC,A,SType,DG,MM,TJ,PJT,PJC,CP,INT} <: SensitivityFunction
struct ODEAdjointSensitivityFunction{dgType,rateType,uType,F,J,PJ,UF,PF,G,JC,GC,A,DG,MM,TJ,PJT,PJC,CP,SType,INT} <: SensitivityFunction
f::F
jac::J
paramjac::PJ
Expand All @@ -26,11 +26,11 @@ struct ODEAdjointSensitivityFunction{dgType,rateType,uType,F,J,PJ,UF,PF,G,JC,GC,
integrator::INT
end

function ODEAdjointSensitivityFunction(f,jac,paramjac,uf,pf,g,u0,
@noinline function ODEAdjointSensitivityFunction(f,jac,paramjac,uf,pf,g,u0,
jac_config,g_grad_config,paramjac_config,
p,f_cache,alg,discrete,y,sol,dg,mm,checkpoints)
numparams = length(p)
numindvar = length(u0)
numparams::Int = length(p)
numindvar::Int = length(u0)
# if there is an analytical Jacobian provided, we are not going to do automatic `jac*vec`
isautojacvec = DiffEqBase.has_jac(f) ? false : get_jacvec(alg)
J = isautojacvec ? nothing : similar(sol.prob.u0, numindvar, numindvar)
Expand All @@ -47,7 +47,7 @@ function ODEAdjointSensitivityFunction(f,jac,paramjac,uf,pf,g,u0,
nothing
end
dg_val = similar(u0, numindvar) # number of funcs size
ODEAdjointSensitivityFunction(f,jac,paramjac,uf,pf,g,J,pJ,dg_val,
return ODEAdjointSensitivityFunction(f,jac,paramjac,uf,pf,g,J,pJ,dg_val,
jac_config,g_grad_config,paramjac_config,
alg,numparams,numindvar,f_cache,
discrete,y,sol,dg,mm,checkpoints,integrator)
Expand All @@ -58,6 +58,7 @@ function (S::ODEAdjointSensitivityFunction)(du,u,p,t)
idx = length(S.y)
y = S.y
isautojacvec = DiffEqBase.has_jac(S.f) ? false : get_jacvec(S.alg)
sol = S.sol

if isbcksol(S.alg)
λ = @view u[1:idx]
Expand All @@ -67,7 +68,7 @@ function (S::ODEAdjointSensitivityFunction)(du,u,p,t)
_y = @view u[end-idx+1:end]
dy = @view du[end-idx+1:end]
copyto!(y, _y)
isautojacvec || S.sol.prob.f(dy, _y, p, t)
isautojacvec || sol.prob.f(dy, _y, p, t)
else
if ischeckpointing(S.alg)
# assuming that in the forward direction `t0` < `t1`, and the
Expand All @@ -77,7 +78,7 @@ function (S::ODEAdjointSensitivityFunction)(du,u,p,t)
dt = t-t0
integrator = S.integrator
if abs(dt) > integrator.opts.dtmin
S.sol(integrator.u, t0)
sol(integrator.u, t0)
copyto!(integrator.uprev, integrator.u)
integrator.t = t0
# set `iter` to some arbitrary integer so that there won't be max maxiters error
Expand All @@ -86,10 +87,10 @@ function (S::ODEAdjointSensitivityFunction)(du,u,p,t)
step!(integrator, dt, true)
# `integrator.u` is aliased to `y`
else
S.sol(y,t)
sol(y,t)
end
else
S.sol(y,t)
sol(y,t)
end
if isquad(S.alg)
λ = u
Expand Down Expand Up @@ -152,11 +153,10 @@ function (S::ODEAdjointSensitivityFunction)(du,u,p,t)
end

# g is either g(t,u,p) or discrete g(t,u,i)
function ODEAdjointProblem(sol,g,t=nothing,dg=nothing,
@noinline function ODEAdjointProblem(sol,g,t=nothing,dg=nothing,
alg=SensitivityAlg();
checkpoints=sol.t,
callback=CallbackSet(),mass_matrix=I)

f = sol.prob.f
tspan = (sol.prob.tspan[2],sol.prob.tspan[1])
t != nothing && (tspan = (t[end],t[1]))
Expand All @@ -166,25 +166,28 @@ function ODEAdjointProblem(sol,g,t=nothing,dg=nothing,
p = sol.prob.p
# if there is an analytical Jacobian provided, we are not going to do automatic `jac*vec`
isautojacvec = DiffEqBase.has_jac(f) ? false : get_jacvec(alg)
p == nothing && error("You must have parameters to use parameter sensitivity calculations!")
uf = DiffEqDiffTools.UJacobianWrapper(f,tspan[2],p)
pg = UGradientWrapper(g,tspan[2],p)
p === nothing && error("You must have parameters to use parameter sensitivity calculations!")

u0 = zero(sol.prob.u0)

if DiffEqBase.has_jac(f) || isautojacvec
jac_config = nothing
uf = nothing
else
jac_config = build_jac_config(alg,uf,u0)
uf = DiffEqDiffTools.UJacobianWrapper(f,tspan[2],p)
end

if !discrete
if dg != nothing || isautojacvec
pg = nothing
pg_config = nothing
else
pg = UGradientWrapper(g,tspan[2],p)
pg_config = build_grad_config(alg,pg,u0,p)
end
else
pg = nothing
pg_config = nothing
end

Expand All @@ -200,7 +203,7 @@ function ODEAdjointProblem(sol,g,t=nothing,dg=nothing,
end
end

len = isquad(alg) ? length(u0) : length(u0)+length(p)
len::Int = isquad(alg) ? length(u0) : length(u0)+length(p)
λ = similar(u0, len)
sense = ODEAdjointSensitivityFunction(f,f.jac,f.paramjac,
uf,pf,pg,u0,jac_config,pg_config,paramjac_config,
Expand Down Expand Up @@ -331,7 +334,7 @@ function adjoint_sensitivities(sol,alg,g,t=nothing,dg=nothing;
!isq && return adj_sol[end][(1:length(sol.prob.p)) .+ length(sol.prob.u0)]'
integrand = AdjointSensitivityIntegrand(sol,adj_sol)

if t == nothing
if t === nothing
res,err = quadgk(integrand,sol.prob.tspan[1],sol.prob.tspan[2],
atol=iabstol,rtol=ireltol)
else
Expand Down
14 changes: 7 additions & 7 deletions src/derivative_wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,13 @@ Base.@pure function determine_chunksize(u,CS)
end
end

Base.@pure alg_autodiff(alg::SensitivityAlg{CS,AD,FDT}) where {CS,AD,FDT} = AD
Base.@pure get_chunksize(alg::SensitivityAlg{CS,AD,FDT}) where {CS,AD,FDT} = CS
Base.@pure diff_type(alg::SensitivityAlg{CS,AD,FDT}) where {CS,AD,FDT} = FDT
Base.@pure get_jacvec(alg::SensitivityAlg{CS,AD,FDT}) where {CS,AD,FDT} = alg.autojacvec
Base.@pure isquad(alg::SensitivityAlg{CS,AD,FDT}) where {CS,AD,FDT} = alg.quad
Base.@pure isbcksol(alg::SensitivityAlg{CS,AD,FDT}) where {CS,AD,FDT} = alg.backsolve
Base.@pure ischeckpointing(alg::SensitivityAlg{CS,AD,FDT}) where {CS,AD,FDT} = alg.checkpointing
@inline alg_autodiff(alg::SensitivityAlg{CS,AD,FDT}) where {CS,AD,FDT} = AD
@inline get_chunksize(alg::SensitivityAlg{CS,AD,FDT}) where {CS,AD,FDT} = CS
@inline diff_type(alg::SensitivityAlg{CS,AD,FDT}) where {CS,AD,FDT} = FDT
@inline get_jacvec(alg::SensitivityAlg{CS,AD,FDT}) where {CS,AD,FDT} = alg.autojacvec
@inline isquad(alg::SensitivityAlg{CS,AD,FDT}) where {CS,AD,FDT} = alg.quad
@inline isbcksol(alg::SensitivityAlg{CS,AD,FDT}) where {CS,AD,FDT} = alg.backsolve
@inline ischeckpointing(alg::SensitivityAlg{CS,AD,FDT}) where {CS,AD,FDT} = alg.checkpointing

function jacobian!(J::AbstractMatrix{<:Number}, f, x::AbstractArray{<:Number},
fx::AbstractArray{<:Number}, alg::SensitivityAlg, jac_config)
Expand Down
6 changes: 3 additions & 3 deletions test/adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ easy_res3 = adjoint_sensitivities(solb,Tsit5(),dg,t,abstol=1e-14,
reltol=1e-14,iabstol=1e-14,ireltol=1e-12,sensealg=SensitivityAlg(quad=false,backsolve=false))
easy_res4 = adjoint_sensitivities(solb,Tsit5(),dg,t,abstol=1e-14,
reltol=1e-14,iabstol=1e-14,ireltol=1e-12,sensealg=SensitivityAlg(backsolve=true))
easy_res5 = adjoint_sensitivities(sol,Kvaerno5(nlsolve=NLAnderson(), smooth_est=false),dg,t,abstol=1e-14,
reltol=1e-14,iabstol=1e-14,ireltol=1e-12,sensealg=SensitivityAlg(backsolve=true))
easy_res5 = adjoint_sensitivities(sol,Kvaerno5(nlsolve=NLAnderson(), smooth_est=false),dg,t,abstol=1e-12,
reltol=1e-10,iabstol=1e-14,ireltol=1e-12,sensealg=SensitivityAlg(backsolve=true))
easy_res6 = adjoint_sensitivities(solb,Tsit5(),dg,t,abstol=1e-14,
reltol=1e-14,iabstol=1e-14,ireltol=1e-12,
sensealg=SensitivityAlg(checkpointing=true,quad=true),
Expand All @@ -56,7 +56,7 @@ res,err = quadgk(integrand,0.0,10.0,atol=1e-14,rtol=1e-12)
@test isapprox(res, easy_res2, rtol = 1e-10)
@test isapprox(res, easy_res3, rtol = 1e-10)
@test isapprox(res, easy_res4, rtol = 1e-10)
@test isapprox(res, easy_res5, rtol = 1e-9)
@test isapprox(res, easy_res5, rtol = 1e-7)
@test isapprox(res, easy_res6, rtol = 1e-9)
@test isapprox(res, easy_res7, rtol = 1e-9)

Expand Down

0 comments on commit 21540e3

Please sign in to comment.