diff --git a/src/remake.jl b/src/remake.jl index ecbb46ce1..d23fb94d7 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -33,19 +33,37 @@ end isrecompile(prob::ODEProblem{iip}) where {iip} = (prob.f isa ODEFunction) ? !isfunctionwrapper(prob.f.f) : true -function remake(thing::ODEProblem; kwargs...) - T = remaker_of(thing) - tup = merge(merge(struct_as_namedtuple(thing),thing.kwargs),kwargs) - if !isrecompile(thing) - if isinplace(thing) - f = wrapfun_iip(unwrap_fw(tup.f.f),(tup.u0,tup.u0,tup.p,tup.tspan[1])) +function remake(prob::ODEProblem; f=missing, + u0=missing, + tspan=missing, + p=missing, + kwargs...) + if f === missing + f = prob.f + elseif !isrecompile(prob) + if isinplace(prob) + f = wrapfun_iip(unwrap_fw(f),(u0,u0,p,tspan[1])) else - f = wrapfun_oop(unwrap_fw(tup.f.f),(tup.u0,tup.p,tup.tspan[1])) + f = wrapfun_oop(unwrap_fw(f),(u0,p,tspan[1])) end - tup2 = (f = convert(ODEFunction{isinplace(thing)},f),) - tup = merge(tup, tup2) + f = convert(ODEFunction{isinplace(prob)},f) + elseif prob.f isa ODEFunction # avoid the SplitFunction etc. cases + f = convert(ODEFunction{isinplace(prob)},f) + end + + if u0 === missing + u0 = prob.u0 + end + + if tspan === missing + tspan = prob.tspan + end + + if p === missing + p = prob.p end - T(; tup...) + + ODEProblem{isinplace(prob)}(f,u0,tspan,p,prob.problem_type;prob.kwargs..., kwargs...) end function remake(thing::AbstractJumpProblem; kwargs...) diff --git a/test/downstream/symbol_indexing.jl b/test/downstream/symbol_indexing.jl index eefc92e71..6c53af0d2 100644 --- a/test/downstream/symbol_indexing.jl +++ b/test/downstream/symbol_indexing.jl @@ -8,14 +8,14 @@ eqs = [D(x) ~ σ*(y-x), D(y) ~ x*(ρ-z)-y, D(z) ~ x*y - β*z] -lorenz1 = ODESystem(eqs,name=:lorenz1) -lorenz2 = ODESystem(eqs,name=:lorenz2) +@named lorenz1 = ODESystem(eqs) +@named lorenz2 = ODESystem(eqs) @parameters γ @variables a(t) α(t) connections = [0 ~ lorenz1.x + lorenz2.y + a*γ, α ~ 2lorenz1.x + a*γ] -sys = ODESystem(connections,t,[a,α],[γ],systems=[lorenz1,lorenz2]) +@named sys = ODESystem(connections,t,[a,α],[γ],systems=[lorenz1,lorenz2]) sys_simplified = structural_simplify(sys) u0 = [lorenz1.x => 1.0,