Skip to content

Commit

Permalink
fix: better handle split functions in remake(::AbstractSciMLFunction)
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Dec 13, 2024
1 parent c4a4692 commit d04a6eb
Showing 1 changed file with 25 additions and 17 deletions.
42 changes: 25 additions & 17 deletions src/remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,19 +122,23 @@ end
by keyword arguments. For stochastic functions (e.g. `SDEFunction`) the `g` keyword argument
is used to override `func.g`. For split functions (e.g. `SplitFunction`) the `f2` keyword
argument is used to override `func.f2`, and `f` is used for `func.f1`. If
`f isa AbstractSciMLFunction`, properties of `f` will override those of `func` (but not ones
provided via keyword arguments). Properties of `f` that are `nothing` will fall back to those
in `func` (unless provided via keyword arguments). If `f` is a different type of
`AbstractSciMLFunction` from `func`, the returned function will be of the kind of `f`.
`f isa AbstractSciMLFunction` and `func` is not a split function, properties of `f` will
override those of `func` (but not ones provided via keyword arguments). Properties of `f` that
are `nothing` will fall back to those in `func` (unless provided via keyword arguments). If
`f` is a different type of `AbstractSciMLFunction` from `func`, the returned function will be
of the kind of `f` unless `func` is a split function. If `func` is a split function, `f` and
`f2` will be wrapped in the appropriate `AbstractSciMLFunction` type with the same `isinplace`
and `specialization` as `func`.
"""
function remake(func::AbstractSciMLFunction; f = missing, g = missing, f2 = missing, kwargs...)
function remake(
func::AbstractSciMLFunction; f = missing, g = missing, f2 = missing, kwargs...)
# retain iip and spec of original function
iip = isinplace(func)
spec = specialization(func)
# retain properties of original function
props = getproperties(func)

if f === missing
if f === missing || is_split_function(func)
# if no `f` is provided, create the same type of SciMLFunction
T = parameterless_type(func)
f = isdefined(func, :f) ? func.f : func.f1
Expand All @@ -153,32 +157,36 @@ function remake(func::AbstractSciMLFunction; f = missing, g = missing, f2 = miss
# minor hack to avoid breaking MTK, since prior to ~9.57 in `remake_initialization_data`
# it creates a `NonlinearFunction` inside a `NonlinearFunction`. Just recursively unwrap
# in this case and forget about properties.
while f isa AbstractSciMLFunction
while !is_split_function(T) && f isa AbstractSciMLFunction
f = isdefined(f, :f) ? f.f : f.f1
end

props = @delete props.f
props = @delete props.f1

if isdefined(func, :g)
# For SDEs/SDDEs where `g` is not a keyword
g = coalesce(g, func.g)
if is_split_function(T)
f = f isa AbstractSciMLOperator ? f : split_function_f_wrapper(T){iip, spec}(f)
end

props = @delete props.g
T{iip, spec}(f, g; props..., kwargs...)
elseif isdefined(func, :f2)
args = (f,)
if is_split_function(T)
# For SplitFunction
# we don't do the same thing as `g`, because for SDEs `g` is
# stored in the problem as well, whereas for Split ODEs etc
# f2 is a part of the function. Thus, if the user provides
# a SciMLFunction for `f` which contains `f2` we use that.
f2 = coalesce(f2, get(props, :f2, missing), func.f2)

f2 = split_function_f_wrapper(T){iip, spec}(f2)
props = @delete props.f2
T{iip, spec}(f, f2; props..., kwargs...)
else
T{iip, spec}(f; props..., kwargs...)
args = (args..., f2)
end
if isdefined(func, :g)
# For SDEs/SDDEs where `g` is not a keyword
g = coalesce(g, func.g)
props = @delete props.g
args = (args..., g)
end
T{iip, spec}(args...; props..., kwargs...)
end

"""
Expand Down

0 comments on commit d04a6eb

Please sign in to comment.