Skip to content

Commit

Permalink
feat: add lazy initialization to remake
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Nov 28, 2024
1 parent a6a4a1f commit f09aee5
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 1 deletion.
12 changes: 12 additions & 0 deletions src/initialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -225,3 +225,15 @@ function get_initial_values(prob, valp, f, alg::OverrideInit,

return u0, p, success
end

function is_trivial_initialization(initdata::OverrideInitData)
state_values(initdata.initializeprob) === nothing
end

function is_trivial_initialization(f::AbstractSciMLFunction)
has_initialization_data(f) && is_trivial_initialization(f.initialization_data)
end

function is_trivial_initialization(prob::AbstractSciMLProblem)
is_trivial_initialization(prob.f)
end
11 changes: 10 additions & 1 deletion src/remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ function remake(prob::ODEProblem; f = missing,
interpret_symbolicmap = true,
build_initializeprob = true,
use_defaults = false,
lazy_initialization = !is_trivial_initialization(prob),
_kwargs...)
if tspan === missing
tspan = prob.tspan
Expand Down Expand Up @@ -170,13 +171,21 @@ function remake(prob::ODEProblem; f = missing,
_f = ODEFunction{isinplace(prob), specialization(prob.f)}(f)
end

if kwargs === missing
prob = if kwargs === missing
ODEProblem{isinplace(prob)}(
_f, newu0, tspan, newp, prob.problem_type; prob.kwargs...,
_kwargs...)
else
ODEProblem{isinplace(prob)}(_f, newu0, tspan, newp, prob.problem_type; kwargs...)
end

if !lazy_initialization
u0, p, _ = get_initial_values(prob, prob, prob.f, OverrideInit(), Val(isinplace(prob)))
@reset prob.u0 = u0
@reset prob.p = p
end

return prob
end

"""
Expand Down
9 changes: 9 additions & 0 deletions test/downstream/modelingtoolkit_remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -274,3 +274,12 @@ end
@test_throws SciMLBase.CyclicDependencyError remake(
prob; u0 = [x => 2y + p, y => q + 3], p = [p => x + y, q => p + 3])
end

@testset "Lazy initialization" begin
@variables x(t) [guess = 1.0] y(t) [guess = 1.0]
@parameters p = missing [guess = 1.0]
@mtkbuild sys = ODESystem([D(x) ~ x, x + y ~ p], t)
prob = ODEProblem(sys, [x => 1.0, y => 1.0], (0.0, 1.0))
prob2 = remake(prob; u0 = [x => 2.0])
@test prob2.ps[p] 3.0
end

0 comments on commit f09aee5

Please sign in to comment.