From a7f661a616b2f90ea1ccacd38373e9b3ee7be65b Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 8 Mar 2024 12:08:26 +0530 Subject: [PATCH] fixup! fix: fix remake for symbolic indexing, add tests for remake --- src/problems/basic_problems.jl | 2 ++ src/problems/noise_problems.jl | 5 +++++ src/remake.jl | 12 +++--------- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/src/problems/basic_problems.jl b/src/problems/basic_problems.jl index c94f3795f9..dc32d74260 100644 --- a/src/problems/basic_problems.jl +++ b/src/problems/basic_problems.jl @@ -560,3 +560,5 @@ struct SampledIntegralProblem{Y, X, K} <: AbstractIntegralProblem{false} new{typeof(y), typeof(x), typeof(kwargs)}(y, x, dim, kwargs) end end + +SymbolicIndexingInterface.state_values(::AbstractIntegralProblem) = nothing diff --git a/src/problems/noise_problems.jl b/src/problems/noise_problems.jl index 2c5977599d..c289eb553f 100644 --- a/src/problems/noise_problems.jl +++ b/src/problems/noise_problems.jl @@ -12,3 +12,8 @@ end _tspan = promote_tspan(tspan) NoiseProblem{typeof(noise), typeof(_tspan), typeof(kwargs)}(noise, _tspan, seed, kwargs) end + +SymbolicIndexingInterface.parameter_values(::NoiseProblem) = nothing +SymbolicIndexingInterface.is_parameter(::NoiseProblem) = false +SymbolicIndexingInterface.parameter_index(::NoiseProblem) = nothing +SymbolicIndexingInterface.parameter_symbols(::NoiseProblem) = [] diff --git a/src/remake.jl b/src/remake.jl index 4a88999e7c..18e368d2a0 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -344,20 +344,14 @@ function updated_u0_p(prob, u0, p; interpret_symbolicmap = true) end function updated_u0(prob, u0, p) - if u0 === missing - return copy(state_values(prob)) + if u0 === missing || u0 isa Function + return state_values(prob) end if u0 isa Number return u0 end if eltype(u0) <: Pair u0 = Dict(u0) - elseif u0 isa AbstractArray - if length(u0) == length(state_values(prob)) - return u0 - else - throw(ArgumentError("Invalid value for u0: $u0. New length ($length(u0)) does not match length of current problem ($(length(state_values(prob))))")) - end else return u0 end @@ -386,7 +380,7 @@ end function updated_p(prob, p; interpret_symbolicmap = true) if p === missing - return copy(parameter_values(prob)) + return parameter_values(prob) end if eltype(p) <: Pair if interpret_symbolicmap