From 06aaf15bf01d1cde94e48b7450329231c2765473 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 12 Dec 2024 14:14:28 +0530 Subject: [PATCH] fix: better handle split functions in `remake(::AbstractSciMLFunction)` --- src/remake.jl | 42 +++++++++++++++++++++++++----------------- 1 file changed, 25 insertions(+), 17 deletions(-) diff --git a/src/remake.jl b/src/remake.jl index 412e5a011..46cbc7765 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -122,19 +122,23 @@ end by keyword arguments. For stochastic functions (e.g. `SDEFunction`) the `g` keyword argument is used to override `func.g`. For split functions (e.g. `SplitFunction`) the `f2` keyword argument is used to override `func.f2`, and `f` is used for `func.f1`. If -`f isa AbstractSciMLFunction`, properties of `f` will override those of `func` (but not ones -provided via keyword arguments). Properties of `f` that are `nothing` will fall back to those -in `func` (unless provided via keyword arguments). If `f` is a different type of -`AbstractSciMLFunction` from `func`, the returned function will be of the kind of `f`. +`f isa AbstractSciMLFunction` and `func` is not a split function, properties of `f` will +override those of `func` (but not ones provided via keyword arguments). Properties of `f` that +are `nothing` will fall back to those in `func` (unless provided via keyword arguments). If +`f` is a different type of `AbstractSciMLFunction` from `func`, the returned function will be +of the kind of `f` unless `func` is a split function. If `func` is a split function, `f` and +`f2` will be wrapped in the appropriate `AbstractSciMLFunction` type with the same `isinplace` +and `specialization` as `func`. """ -function remake(func::AbstractSciMLFunction; f = missing, g = missing, f2 = missing, kwargs...) +function remake( + func::AbstractSciMLFunction; f = missing, g = missing, f2 = missing, kwargs...) # retain iip and spec of original function iip = isinplace(func) spec = specialization(func) # retain properties of original function props = getproperties(func) - if f === missing + if f === missing || is_split_function(func) # if no `f` is provided, create the same type of SciMLFunction T = parameterless_type(func) f = isdefined(func, :f) ? func.f : func.f1 @@ -153,32 +157,36 @@ function remake(func::AbstractSciMLFunction; f = missing, g = missing, f2 = miss # minor hack to avoid breaking MTK, since prior to ~9.57 in `remake_initialization_data` # it creates a `NonlinearFunction` inside a `NonlinearFunction`. Just recursively unwrap # in this case and forget about properties. - while f isa AbstractSciMLFunction + while !is_split_function(T) && f isa AbstractSciMLFunction f = isdefined(f, :f) ? f.f : f.f1 end props = @delete props.f props = @delete props.f1 - if isdefined(func, :g) - # For SDEs/SDDEs where `g` is not a keyword - g = coalesce(g, func.g) + if is_split_function(T) + f = f isa AbstractSciMLOperator ? f : split_function_f_wrapper(T){iip, spec}(f) + end - props = @delete props.g - T{iip, spec}(f, g; props..., kwargs...) - elseif isdefined(func, :f2) + args = (f,) + if is_split_function(T) # For SplitFunction # we don't do the same thing as `g`, because for SDEs `g` is # stored in the problem as well, whereas for Split ODEs etc # f2 is a part of the function. Thus, if the user provides # a SciMLFunction for `f` which contains `f2` we use that. f2 = coalesce(f2, get(props, :f2, missing), func.f2) - + f2 = split_function_f_wrapper(T){iip, spec}(f2) props = @delete props.f2 - T{iip, spec}(f, f2; props..., kwargs...) - else - T{iip, spec}(f; props..., kwargs...) + args = (args..., f2) + end + if isdefined(func, :g) + # For SDEs/SDDEs where `g` is not a keyword + g = coalesce(g, func.g) + props = @delete props.g + args = (args..., g) end + T{iip, spec}(args...; props..., kwargs...) end """