Skip to content

Commit

Permalink
Merge pull request #95 from SciML/remake
Browse files Browse the repository at this point in the history
specialize ODEProblem remake for compile times
  • Loading branch information
ChrisRackauckas authored Aug 7, 2021
2 parents 2df7c8c + 488a7f1 commit 724a9af
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 13 deletions.
38 changes: 28 additions & 10 deletions src/remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,19 +33,37 @@ end

isrecompile(prob::ODEProblem{iip}) where {iip} = (prob.f isa ODEFunction) ? !isfunctionwrapper(prob.f.f) : true

function remake(thing::ODEProblem; kwargs...)
T = remaker_of(thing)
tup = merge(merge(struct_as_namedtuple(thing),thing.kwargs),kwargs)
if !isrecompile(thing)
if isinplace(thing)
f = wrapfun_iip(unwrap_fw(tup.f.f),(tup.u0,tup.u0,tup.p,tup.tspan[1]))
function remake(prob::ODEProblem; f=missing,
u0=missing,
tspan=missing,
p=missing,
kwargs...)
if f === missing
f = prob.f
elseif !isrecompile(prob)
if isinplace(prob)
f = wrapfun_iip(unwrap_fw(f),(u0,u0,p,tspan[1]))
else
f = wrapfun_oop(unwrap_fw(tup.f.f),(tup.u0,tup.p,tup.tspan[1]))
f = wrapfun_oop(unwrap_fw(f),(u0,p,tspan[1]))
end
tup2 = (f = convert(ODEFunction{isinplace(thing)},f),)
tup = merge(tup, tup2)
f = convert(ODEFunction{isinplace(prob)},f)
elseif prob.f isa ODEFunction # avoid the SplitFunction etc. cases
f = convert(ODEFunction{isinplace(prob)},f)
end

if u0 === missing
u0 = prob.u0
end

if tspan === missing
tspan = prob.tspan
end

if p === missing
p = prob.p
end
T(; tup...)

ODEProblem{isinplace(prob)}(f,u0,tspan,p,prob.problem_type;prob.kwargs..., kwargs...)
end

function remake(thing::AbstractJumpProblem; kwargs...)
Expand Down
6 changes: 3 additions & 3 deletions test/downstream/symbol_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@ eqs = [D(x) ~ σ*(y-x),
D(y) ~ x*-z)-y,
D(z) ~ x*y - β*z]

lorenz1 = ODESystem(eqs,name=:lorenz1)
lorenz2 = ODESystem(eqs,name=:lorenz2)
@named lorenz1 = ODESystem(eqs)
@named lorenz2 = ODESystem(eqs)

@parameters γ
@variables a(t) α(t)
connections = [0 ~ lorenz1.x + lorenz2.y + a*γ,
α ~ 2lorenz1.x + a*γ]
sys = ODESystem(connections,t,[a,α],[γ],systems=[lorenz1,lorenz2])
@named sys = ODESystem(connections,t,[a,α],[γ],systems=[lorenz1,lorenz2])
sys_simplified = structural_simplify(sys)

u0 = [lorenz1.x => 1.0,
Expand Down

0 comments on commit 724a9af

Please sign in to comment.