Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add new remake(::AbstractSciMLFunction), fix some remake bugs. #891

Merged
merged 13 commits into from
Dec 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/SciMLBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import FunctionWrappersWrappers
import RuntimeGeneratedFunctions
import EnumX
import ADTypes: ADTypes, AbstractADType
import Accessors: @set, @reset, @delete
import Accessors: @set, @reset, @delete, @insert
using Expronicon.ADT: @match

using Reexport
Expand Down
6 changes: 4 additions & 2 deletions src/initialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,13 @@ function evaluate_f(
return _evaluate_f(integrator, f, isinplace, integrator.du, u, p, t)
end

function evaluate_f(integrator::AbstractDDEIntegrator, prob::AbstractDDEProblem, f, isinplace, u, p, t)
function evaluate_f(
integrator::AbstractDDEIntegrator, prob::AbstractDDEProblem, f, isinplace, u, p, t)
return _evaluate_f(integrator, f, isinplace, u, get_history_function(integrator), p, t)
end

function evaluate_f(integrator::AbstractSDDEIntegrator, prob::AbstractSDDEProblem, f, isinplace, u, p, t)
function evaluate_f(integrator::AbstractSDDEIntegrator,
prob::AbstractSDDEProblem, f, isinplace, u, p, t)
return _evaluate_f(integrator, f, isinplace, u, get_history_function(integrator), p, t)
end

Expand Down
224 changes: 112 additions & 112 deletions src/remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,105 @@ function remake(
_remake_internal(prob; kwargs..., p)
end

"""
$(TYPEDSIGNATURES)

A utility function which merges two `NamedTuple`s `a` and `b`, assuming that the
keys of `a` are a subset of those of `b`. Values in `b` take priority over those
in `a`, except if they are `nothing`. Keys not present in `a` are assumed to have
a value of `nothing`.
"""
function _similar_namedtuple_merge_ignore_nothing(a::NamedTuple, b::NamedTuple)
ks = fieldnames(typeof(b))
return NamedTuple{ks}(ntuple(Val(length(ks))) do i
something(get(b, ks[i], nothing), get(a, ks[i], nothing), Some(nothing))
end)
end

Comment on lines +102 to +116
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isn't this just (;a..., b...)?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, because nothing values in b will override values in a. This function makes it so that if a key has the value nothing in b, it uses the value from a.

"""
remake(func::AbstractSciMLFunction; f = missing, g = missing, f2 = missing, kwargs...)

`remake` the given `func`. Return an `AbstractSciMLFunction` of the same kind, `isinplace` and
`specialization` as `func`. Retain the properties of `func`, except those that are overridden
by keyword arguments. For stochastic functions (e.g. `SDEFunction`) the `g` keyword argument
is used to override `func.g`. For split functions (e.g. `SplitFunction`) the `f2` keyword
argument is used to override `func.f2`, and `f` is used for `func.f1`. If
`f isa AbstractSciMLFunction` and `func` is not a split function, properties of `f` will
override those of `func` (but not ones provided via keyword arguments). Properties of `f` that
are `nothing` will fall back to those in `func` (unless provided via keyword arguments). If
`f` is a different type of `AbstractSciMLFunction` from `func`, the returned function will be
of the kind of `f` unless `func` is a split function. If `func` is a split function, `f` and
`f2` will be wrapped in the appropriate `AbstractSciMLFunction` type with the same `isinplace`
and `specialization` as `func`.
"""
function remake(
func::AbstractSciMLFunction; f = missing, g = missing, f2 = missing, kwargs...)
# retain iip and spec of original function
iip = isinplace(func)
spec = specialization(func)
# retain properties of original function
props = getproperties(func)

if f === missing || is_split_function(func)
# if no `f` is provided, create the same type of SciMLFunction
T = parameterless_type(func)
f = isdefined(func, :f) ? func.f : func.f1
elseif f isa AbstractSciMLFunction
# if `f` is a SciMLFunction, create that type
T = parameterless_type(f)
# properties of `f` take priority over those in the existing `func`
# ignore properties of `f` which are `nothing` but present in `func`
props = _similar_namedtuple_merge_ignore_nothing(props, getproperties(f))
f = isdefined(f, :f) ? f.f : f.f1
else
# if `f` is provided but not a SciMLFunction, create the same type
T = parameterless_type(func)
end

# minor hack to avoid breaking MTK, since prior to ~9.57 in `remake_initialization_data`
# it creates a `NonlinearFunction` inside a `NonlinearFunction`. Just recursively unwrap
# in this case and forget about properties.
while !is_split_function(T) && f isa AbstractSciMLFunction
f = isdefined(f, :f) ? f.f : f.f1
end

props = @delete props.f
props = @delete props.f1

args = (f,)
if is_split_function(T)
# for DynamicalSDEFunction and SplitFunction
if isdefined(props, :cache)
props = @insert props._func_cache = props.cache
props = @delete props.cache
end

# `f1` and `f2` are wrapped in another SciMLFunction, unless they're
# already wrapped in the appropriate type or are an `AbstractSciMLOperator`
if !(f isa Union{AbstractSciMLOperator, split_function_f_wrapper(T)})
f = split_function_f_wrapper(T){iip, spec}(f)
end
# For SplitFunction
# we don't do the same thing as `g`, because for SDEs `g` is
# stored in the problem as well, whereas for Split ODEs etc
# f2 is a part of the function. Thus, if the user provides
# a SciMLFunction for `f` which contains `f2` we use that.
f2 = coalesce(f2, get(props, :f2, missing), func.f2)
if !(f2 isa Union{AbstractSciMLOperator, split_function_f_wrapper(T)})
f2 = split_function_f_wrapper(T){iip, spec}(f2)
end
props = @delete props.f2
args = (args..., f2)
end
if isdefined(func, :g)
# For SDEs/SDDEs where `g` is not a keyword
g = coalesce(g, func.g)
props = @delete props.g
args = (args..., g)
end
T{iip, spec}(args...; props..., kwargs...)
end

"""
remake(prob::ODEProblem; f = missing, u0 = missing, tspan = missing,
p = missing, kwargs = missing, _kwargs...)
Expand Down Expand Up @@ -135,53 +234,26 @@ function remake(prob::ODEProblem; f = missing,
initialization_data = nothing
end

if f === missing
if specialization(prob.f) === FunctionWrapperSpecialize
ptspan = promote_tspan(tspan)
if iip
_f = ODEFunction{iip, FunctionWrapperSpecialize}(
wrapfun_iip(
unwrapped_f(prob.f.f),
(newu0, newu0, newp,
ptspan[1])); initialization_data)
else
_f = ODEFunction{iip, FunctionWrapperSpecialize}(
wrapfun_oop(
unwrapped_f(prob.f.f),
(newu0, newp,
ptspan[1])); initialization_data)
end
else
_f = prob.f
if __has_initialization_data(_f)
props = getproperties(_f)
@reset props.initialization_data = initialization_data
props = values(props)
_f = parameterless_type(_f){iip, specialization(_f), map(typeof, props)...}(props...)
end
end
elseif f isa AbstractODEFunction
_f = f
elseif specialization(prob.f) === FunctionWrapperSpecialize
f = coalesce(f, prob.f)
f = remake(prob.f; f, initialization_data)

if specialization(f) === FunctionWrapperSpecialize
ptspan = promote_tspan(tspan)
if iip
_f = ODEFunction{iip, FunctionWrapperSpecialize}(wrapfun_iip(f,
(newu0, newu0, newp,
ptspan[1])))
f = remake(
f; f = wrapfun_iip(unwrapped_f(f.f), (newu0, newu0, newp, ptspan[1])))
else
_f = ODEFunction{iip, FunctionWrapperSpecialize}(wrapfun_oop(f,
(newu0, newp, ptspan[1])))
f = remake(
f; f = wrapfun_oop(unwrapped_f(f.f), (newu0, newu0, newp, ptspan[1])))
end
else
_f = ODEFunction{isinplace(prob), specialization(prob.f)}(f)
end

prob = if kwargs === missing
ODEProblem{isinplace(prob)}(
_f, newu0, tspan, newp, prob.problem_type; prob.kwargs...,
ODEProblem{iip}(
f, newu0, tspan, newp, prob.problem_type; prob.kwargs...,
_kwargs...)
else
ODEProblem{isinplace(prob)}(_f, newu0, tspan, newp, prob.problem_type; kwargs...)
ODEProblem{iip}(f, newu0, tspan, newp, prob.problem_type; kwargs...)
end

if lazy_initialization === nothing
Expand Down Expand Up @@ -395,42 +467,6 @@ function remake(prob::SDEProblem;
return prob
end

"""
remake(func::SDEFunction; f = missing, g = missing,
mass_matrix = missing, analytic = missing, kwargs...)

Remake the given `SDEFunction`.
"""
function remake(func::Union{SDEFunction, SDDEFunction};
f = missing,
g = missing,
mass_matrix = missing,
analytic = missing,
sys = missing,
kwargs...)
props = getproperties(func)
props = @delete props.f
props = @delete props.g
@reset props.mass_matrix = coalesce(mass_matrix, func.mass_matrix)
@reset props.analytic = coalesce(analytic, func.analytic)
@reset props.sys = coalesce(sys, func.sys)

if f === missing
f = func.f
end

if g === missing
g = func.g
end

if f isa AbstractSciMLFunction
f = f.f
end

T = func isa SDEFunction ? SDEFunction : SDDEFunction
return T{isinplace(func)}(f, g; props..., kwargs...)
end

function remake(prob::DDEProblem; f = missing, h = missing, u0 = missing,
tspan = missing, p = missing, constant_lags = missing,
dependent_lags = missing, order_discontinuity_t0 = missing,
Expand Down Expand Up @@ -497,28 +533,6 @@ function remake(prob::DDEProblem; f = missing, h = missing, u0 = missing,
return prob
end

function remake(func::DDEFunction;
f = missing,
mass_matrix = missing,
analytic = missing,
sys = missing,
kwargs...)
props = getproperties(func)
props = @delete props.f
@reset props.mass_matrix = coalesce(mass_matrix, func.mass_matrix)
@reset props.analytic = coalesce(analytic, func.analytic)
@reset props.sys = coalesce(sys, func.sys)

if f === missing
f = func.f
end
if f isa AbstractSciMLFunction
f = f.f
end

return DDEFunction{isinplace(func)}(f; props..., kwargs...)
end

function remake(prob::SDDEProblem;
f = missing,
g = missing,
Expand Down Expand Up @@ -706,6 +720,7 @@ function remake(prob::NonlinearProblem;
initialization_data = nothing
end

f = coalesce(f, prob.f)
f = remake(prob.f; f, initialization_data)

if problem_type === missing
Expand Down Expand Up @@ -737,22 +752,6 @@ function remake(prob::NonlinearProblem;
return prob
end

function remake(func::NonlinearFunction;
f = missing,
kwargs...)
props = getproperties(func)
props = @delete props.f

if f === missing
f = func.f
end
if f isa AbstractSciMLFunction
f = f.f
end

return NonlinearFunction{isinplace(func)}(f; props..., kwargs...)
end

"""
remake(prob::NonlinearLeastSquaresProblem; f = missing, u0 = missing, p = missing,
kwargs = missing, _kwargs...)
Expand All @@ -775,6 +774,7 @@ function remake(prob::NonlinearLeastSquaresProblem; f = missing, u0 = missing, p
initialization_data = nothing
end

f = coalesce(f, prob.f)
f = remake(prob.f; f, initialization_data)

prob = if kwargs === missing
Expand Down
14 changes: 14 additions & 0 deletions src/scimlfunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4601,6 +4601,20 @@ has_Wfact_t(f::JacobianWrapper) = has_Wfact_t(f.f)
has_paramjac(f::JacobianWrapper) = has_paramjac(f.f)
has_colorvec(f::JacobianWrapper) = has_colorvec(f.f)

is_split_function(x) = is_split_function(typeof(x))
is_split_function(::Type) = false
function is_split_function(::Type{T}) where {T <: Union{
SplitFunction, SplitSDEFunction, DynamicalODEFunction,
DynamicalDDEFunction, DynamicalSDEFunction}}
true
end

split_function_f_wrapper(::Type{<:SplitFunction}) = ODEFunction
split_function_f_wrapper(::Type{<:SplitSDEFunction}) = SDEFunction
split_function_f_wrapper(::Type{<:DynamicalODEFunction}) = ODEFunction
split_function_f_wrapper(::Type{<:DynamicalDDEFunction}) = DDEFunction
split_function_f_wrapper(::Type{<:DynamicalSDEFunction}) = DDEFunction

######### Additional traits

islinear(::AbstractDiffEqFunction) = false
Expand Down
11 changes: 11 additions & 0 deletions test/remake_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -372,3 +372,14 @@ end
prob = ODEProblem(ODEFunction(foo; sys), [1.5, 2.5], (0.0, 1.0), [3.5, 4.5])
@test_nowarn remake(prob; u0 = [:x => nothing], p = [:a => nothing])
end

@testset "retain properties of `SciMLFunction` passed to `remake`" begin
u0 = [1.0; 2.0; 3.0]
p = [10.0, 20.0, 30.0]
sys = SymbolCache([:x, :y, :z], [:a, :b, :c], :t)
fn = NonlinearFunction(nllorenz!; sys, resid_prototype = zeros(Float64, 3))
prob = NonlinearProblem(fn, u0, p)
fn2 = NonlinearFunction(nllorenz!; resid_prototype = zeros(Float32, 3))
prob2 = remake(prob; f = fn2)
@test prob2.f.resid_prototype isa Vector{Float32}
end
Loading