Skip to content

Commit

Permalink
refactor: remove ParameterIndexingProxy, use new getp and setp from SII
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Nov 28, 2023
1 parent a4f2800 commit d1db9b7
Show file tree
Hide file tree
Showing 9 changed files with 64 additions and 91 deletions.
9 changes: 4 additions & 5 deletions src/integrator_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -445,8 +445,6 @@ function Base.getproperty(A::DEIntegrator, sym::Symbol)
if sym === :destats && hasfield(typeof(A), :stats)
@warn "destats has been deprecated for stats"
getfield(A, :stats)
elseif sym == SymbolicIndexingInterface.PARAMETER_INDEXING_PROXY_PROPERTY_NAME
return ParameterIndexingProxy(A)
else
return getfield(A, sym)
end
Expand All @@ -462,8 +460,8 @@ Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, ::ScalarSymboli
if is_variable(A, sym)
return A[variable_index(A, sym)]
elseif is_parameter(A, sym)
Base.depwarn("Indexing with parameters is deprecated. Use `$(nameof(typeof(A))).$(SymbolicIndexingInterface.PARAMETER_INDEXING_PROXY_PROPERTY_NAME)[$sym]` for parameter indexing.", :parameter_getindex)
return ParameterIndexingProxy(A)[sym]
Base.depwarn("Indexing with parameters is deprecated. Use `getp(sys, $sym)(integrator)` for parameter indexing.", :parameter_getindex)
return getp(A, sym)(A)
elseif is_independent_variable(A, sym)
return A.t
elseif is_observed(A, sym)
Expand Down Expand Up @@ -502,7 +500,8 @@ function Base.setindex!(A::DEIntegrator, val, sym)
if is_variable(A, sym)
A.u[variable_index(A, sym)] = val
elseif is_parameter(A, sym)
A.p[parameter_index(A, sym)] = val
Base.depwarn("Parameter indexing is deprecated. Use `setp(sys, $sym)(integrator, $val)` to set parameter value.", :parameter_setindex)
setp(A, sym)(A, val)
else
error("Invalid indexing of integrator: $sym is not a state or parameter, it may be an observed variable.")
end
Expand Down
19 changes: 4 additions & 15 deletions src/problems/problem_interface.jl
Original file line number Diff line number Diff line change
@@ -1,25 +1,13 @@
SymbolicIndexingInterface.symbolic_container(prob::AbstractSciMLProblem) = prob.f
SymbolicIndexingInterface.parameter_values(prob::AbstractSciMLProblem) = prob.p

function Base.propertynames(prob::AbstractSciMLProblem)
(fieldnames(typeof(prob))..., :ps)
end

function Base.getproperty(prob::AbstractSciMLProblem, sym::Symbol)
if sym == SymbolicIndexingInterface.PARAMETER_INDEXING_PROXY_PROPERTY_NAME
return ParameterIndexingProxy(prob)
else
return getfield(prob, sym)
end
end

Base.@propagate_inbounds function Base.getindex(prob::AbstractSciMLProblem, sym)
if symbolic_type(sym) == ScalarSymbolic()
if is_variable(prob.f, sym)
return prob.u0[variable_index(prob.f, sym)]
elseif is_parameter(prob.f, sym)
Base.depwarn("Indexing with parameters is deprecated. Use `$(nameof(typeof(prob))).$(SymbolicIndexingInterface.PARAMETER_INDEXING_PROXY_PROPERTY_NAME)[$sym]` for parameter indexing.", :parameter_getindex)
return getproperty(prob, SymbolicIndexingInterface.PARAMETER_INDEXING_PROXY_PROPERTY_NAME)[sym]
Base.depwarn("Indexing with parameters is deprecated. Use `getp(prob, $sym)(prob)` for parameter indexing.", :parameter_getindex)
return getp(prob, sym)(prob)
elseif is_independent_variable(prob.f, sym)
return getindepsym(prob)
elseif is_observed(prob.f, sym)
Expand Down Expand Up @@ -49,7 +37,8 @@ function ___internal_setindex!(prob::AbstractSciMLProblem, val, sym)
if is_variable(prob.f, sym)
prob.u0[variable_index(prob.f, sym)] = val
elseif is_parameter(prob.f, sym)
prob.p[parameter_index(prob.f, sym)] = val
Base.depwarn("Indexing with parameters is deprecated. Use `setp(prob, $sym)(prob, $val)` to set parameter value.", :parameter_setindex)
setp(prob, sym)(prob, val)
else
error("Invalid indexing of problem: $sym is not a state or parameter, it may be an observed variable.")
end
Expand Down
1 change: 0 additions & 1 deletion src/solutions/dae_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ Base.@propagate_inbounds function Base.getproperty(x::AbstractDAESolution, s::Sy
if s === :destats
Base.depwarn("`sol.destats` is deprecated. Use `sol.stats` instead.", "sol.destats")
return getfield(x, :stats)
elseif s === SymbolicIndexingInterface.PARAMETER_INDEXING_PROXY_PROPERTY_NAME
end
return getfield(x, s)
end
Expand Down
2 changes: 0 additions & 2 deletions src/solutions/ode_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,6 @@ Base.@propagate_inbounds function Base.getproperty(x::AbstractODESolution, s::Sy
if s === :destats
Base.depwarn("`sol.destats` is deprecated. Use `sol.stats` instead.", "sol.destats")
return getfield(x, :stats)
elseif s === SymbolicIndexingInterface.PARAMETER_INDEXING_PROXY_PROPERTY_NAME
return ParameterIndexingProxy(x)
end
return getfield(x, s)
end
Expand Down
2 changes: 0 additions & 2 deletions src/solutions/optimization_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,6 @@ Base.@propagate_inbounds function Base.getproperty(x::AbstractOptimizationSoluti
Base.depwarn("`sol.prob` is deprecated. Use getters like `get_p` or `get_syms` on `sol` instead.",
"sol.prob")
return getfield(x, :cache)
elseif s === SymbolicIndexingInterface.PARAMETER_INDEXING_PROXY_PROPERTY_NAME
return ParameterIndexingProxy(x)
end
return getfield(x, s)
end
Expand Down
2 changes: 0 additions & 2 deletions src/solutions/rode_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,6 @@ Base.@propagate_inbounds function Base.getproperty(x::AbstractRODESolution, s::S
if s === :destats
Base.depwarn("`sol.destats` is deprecated. Use `sol.stats` instead.", "sol.destats")
return getfield(x, :stats)
elseif s === SymbolicIndexingInterface.PARAMETER_INDEXING_PROXY_PROPERTY_NAME
return ParameterIndexingProxy(x)
end
return getfield(x, s)
end
Expand Down
12 changes: 2 additions & 10 deletions src/solutions/solution_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,21 +58,13 @@ SymbolicIndexingInterface.is_time_dependent(::AbstractNoTimeSolution) = false
# TODO make this nontrivial once dynamic state selection works
SymbolicIndexingInterface.constant_structure(::AbstractSolution) = true

function Base.getproperty(A::AbstractSolution, sym::Symbol)
if sym == SymbolicIndexingInterface.PARAMETER_INDEXING_PROXY_PROPERTY_NAME
return ParameterIndexingProxy(A)
else
return getfield(A, sym)
end
end

Base.@propagate_inbounds function Base.getindex(A::AbstractNoTimeSolution, sym)
if symbolic_type(sym) == ScalarSymbolic()
if is_variable(A, sym)
return A[variable_index(A, sym)]
elseif is_parameter(A, sym)
Base.depwarn("Indexing with parameters is deprecated. Use `$(nameof(typeof(A))).$(SymbolicIndexingInterface.PARAMETER_INDEXING_PROXY_PROPERTY_NAME)[$sym]` for parameter indexing.", :parameter_getindex)
return getproperty(A, SymbolicIndexingInterface.PARAMETER_INDEXING_PROXY_PROPERTY_NAME)[sym]
Base.depwarn("Indexing with parameters is deprecated. Use `getp(sys, $sym)(sol)` for parameter indexing.", :parameter_getindex)
return getp(A, sym)(A)
elseif is_observed(A, sym)
return SymbolicIndexingInterface.observed(A, sym)(A.u, A.prob.p)
else
Expand Down
86 changes: 43 additions & 43 deletions test/downstream/integrator_indexing.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using ModelingToolkit, OrdinaryDiffEq, RecursiveArrayTools, StochasticDiffEq, Test
using ModelingToolkit, OrdinaryDiffEq, RecursiveArrayTools, StochasticDiffEq, SymbolicIndexingInterface, Test

### Tests on non-layered model (everything should work). ###

Expand All @@ -21,30 +21,30 @@ integrator = init(oprob, Rodas4())
@test_deprecated integrator[a]
@test_deprecated integrator[population_model.a]
@test_deprecated integrator[:a]
@test integrator.ps[a] == integrator.ps[population_model.a] == integrator.ps[:a] == 2.0
@test integrator.ps[b] == integrator.ps[population_model.b] == integrator.ps[:b] == 1.0
@test integrator.ps[c] == integrator.ps[population_model.c] == integrator.ps[:c] == 1.0
@test integrator.ps[d] == integrator.ps[population_model.d] == integrator.ps[:d] == 1.0
@test getp(oprob, a)(integrator) == getp(oprob, population_model.a)(integrator) == getp(oprob, :a)(integrator) == 2.0
@test getp(oprob, b)(integrator) == getp(oprob, population_model.b)(integrator) == getp(oprob, :b)(integrator) == 1.0
@test getp(oprob, c)(integrator) == getp(oprob, population_model.c)(integrator) == getp(oprob, :c)(integrator) == 1.0
@test getp(oprob, d)(integrator) == getp(oprob, population_model.d)(integrator) == getp(oprob, :d)(integrator) == 1.0

@test integrator[s1] == integrator[population_model.s1] == integrator[:s1] == 2.0
@test integrator[s2] == integrator[population_model.s2] == integrator[:s2] == 1.0

step!(integrator, 100.0, true)

@test integrator.ps[a] == integrator.ps[population_model.a] == integrator.ps[:a] == 2.0
@test integrator.ps[b] == integrator.ps[population_model.b] == integrator.ps[:b] == 1.0
@test integrator.ps[c] == integrator.ps[population_model.c] == integrator.ps[:c] == 1.0
@test integrator.ps[d] == integrator.ps[population_model.d] == integrator.ps[:d] == 1.0
@test getp(population_model, a)(integrator) == getp(population_model, population_model.a)(integrator) == getp(population_model, :a)(integrator) == 2.0
@test getp(population_model, b)(integrator) == getp(population_model, population_model.b)(integrator) == getp(population_model, :b)(integrator) == 1.0
@test getp(population_model, c)(integrator) == getp(population_model, population_model.c)(integrator) == getp(population_model, :c)(integrator) == 1.0
@test getp(population_model, d)(integrator) == getp(population_model, population_model.d)(integrator) == getp(population_model, :d)(integrator) == 1.0

@test integrator[s1] == integrator[population_model.s1] == integrator[:s1] != 2.0
@test integrator[s2] == integrator[population_model.s2] == integrator[:s2] != 1.0

integrator[a] = 10.0
@test integrator.ps[a] == integrator.ps[population_model.a] == integrator.ps[:a] == 10.0
integrator[population_model.b] = 20.0
@test integrator.ps[b] == integrator.ps[population_model.b] == integrator.ps[:b] == 20.0
integrator[c] = 30.0
@test integrator.ps[c] == integrator.ps[population_model.c] == integrator.ps[:c] == 30.0
setp(oprob, a)(integrator, 10.0)
@test getp(integrator, a)(integrator) == getp(integrator, population_model.a)(integrator) == getp(integrator, :a)(integrator) == 10.0
setp(population_model, population_model.b)(integrator, 20.0)
@test getp(integrator, b)(integrator) == getp(integrator, population_model.b)(integrator) == getp(integrator, :b)(integrator) == 20.0
setp(integrator, c)(integrator, 30.0)
@test getp(integrator, c)(integrator) == getp(integrator, population_model.c)(integrator) == getp(integrator, :c)(integrator) == 30.0

integrator[s1] = 10.0
@test integrator[s1] == integrator[population_model.s1] == integrator[:s1] == 10.0
Expand All @@ -62,19 +62,19 @@ integrator = init(sprob, ImplicitEM())

step!(integrator, 100.0, true)

@test integrator.ps[a] == integrator.ps[noisy_population_model.a] == integrator.ps[:a] == 2.0
@test integrator.ps[b] == integrator.ps[noisy_population_model.b] == integrator.ps[:b] == 1.0
@test integrator.ps[c] == integrator.ps[noisy_population_model.c] == integrator.ps[:c] == 1.0
@test integrator.ps[d] == integrator.ps[noisy_population_model.d] == integrator.ps[:d] == 1.0
@test getp(sprob, a)(integrator) == getp(sprob, noisy_population_model.a)(integrator) == getp(sprob, :a)(integrator) == 2.0
@test getp(sprob, b)(integrator) == getp(sprob, noisy_population_model.b)(integrator) == getp(sprob, :b)(integrator) == 1.0
@test getp(sprob, c)(integrator) == getp(sprob, noisy_population_model.c)(integrator) == getp(sprob, :c)(integrator) == 1.0
@test getp(sprob, d)(integrator) == getp(sprob, noisy_population_model.d)(integrator) == getp(sprob, :d)(integrator) == 1.0
@test integrator[s1] == integrator[noisy_population_model.s1] == integrator[:s1] != 2.0
@test integrator[s2] == integrator[noisy_population_model.s2] == integrator[:s2] != 1.0

integrator[a] = 10.0
@test integrator.ps[a] == integrator.ps[noisy_population_model.a] == integrator.ps[:a] == 10.0
integrator[noisy_population_model.b] = 20.0
@test integrator.ps[b] == integrator.ps[noisy_population_model.b] == integrator.ps[:b] == 20.0
integrator[c] = 30.0
@test integrator.ps[c] == integrator.ps[noisy_population_model.c] == integrator.ps[:c] == 30.0
setp(integrator, a)(integrator, 10.0)
@test getp(noisy_population_model, a)(integrator) == getp(noisy_population_model, noisy_population_model.a)(integrator) == getp(noisy_population_model, :a)(integrator) == 10.0
setp(sprob, noisy_population_model.b)(integrator, 20.0)
@test getp(noisy_population_model, b)(integrator) == getp(noisy_population_model, noisy_population_model.b)(integrator) == getp(noisy_population_model, :b)(integrator) == 20.0
setp(noisy_population_model, c)(integrator, 30.0)
@test getp(noisy_population_model, c)(integrator) == getp(noisy_population_model, noisy_population_model.c)(integrator) == getp(noisy_population_model, :c)(integrator) == 30.0

integrator[s1] = 10.0
@test integrator[s1] == integrator[noisy_population_model.s1] == integrator[:s1] == 10.0
Expand Down Expand Up @@ -137,13 +137,13 @@ step!(integrator, 100.0, true)
@test integrator[lorenz1.x] isa Real
@test integrator[t] isa Real
@test integrator[α] isa Real
@test integrator.ps[γ] isa Real
@test integrator.ps[γ] == 2.0
@test integrator.ps[(lorenz1.σ, lorenz1.ρ)] isa Tuple
@test getp(prob, γ)(integrator) isa Real
@test getp(prob, γ)(integrator) == 2.0
@test getp(prob, (lorenz1.σ, lorenz1.ρ))(integrator) isa Tuple

@test length(integrator[[lorenz1.x, lorenz2.x]]) == 2
@test integrator.ps[[γ, lorenz1.σ]] isa Vector{Float64}
@test length(integrator.ps[[γ, lorenz1.σ]]) == 2
@test getp(integrator, [γ, lorenz1.σ])(integrator) isa Vector{Float64}
@test length(getp(integrator, [γ, lorenz1.σ])(integrator)) == 2

@variables q(t)[1:2] = [1.0, 2.0]
eqs = [D(q[1]) ~ 2q[1]
Expand Down Expand Up @@ -178,23 +178,23 @@ integrator2 = init(prob2, Tsit5())
end

# Tests various interface methods:
@test_throws Any integrator.ps[σ]
@test in(integrator.ps[lorenz1.σ], integrator.p)
@test in(integrator.ps[lorenz2.σ], integrator.p)
@test_throws Any sol.ps[]
@test_throws Any getp(sys, σ)(integrator)
@test in(getp(sys, lorenz1.σ)(integrator), integrator.p)
@test in(getp(sys, lorenz2.σ)(integrator), integrator.p)
@test_throws Any getp(sol, )(sol)

@test_throws Any integrator[x]
@test in(integrator[lorenz1.x], integrator.u)
@test in(integrator[lorenz2.x], integrator.u)
@test_throws Any sol[:x]

@test_throws Any integrator.ps[σ]=2.0
integrator[lorenz1.σ] = 2.0
@test integrator.ps[lorenz1.σ] == 2.0
@test integrator.ps[lorenz2.σ] != 2.0
integrator[lorenz2.σ] = 2.0
@test integrator.ps[lorenz2.σ] == 2.0
@test_throws Any sol.ps[]
@test_throws Any getp(sol, :x)(sol)

@test_throws Any setp(integrator, σ)(integrator, 2.0)
setp(integrator, lorenz1.σ)(integrator, 2.0)
@test getp(integrator, lorenz1.σ)(integrator) == 2.0
@test getp(integrator, lorenz2.σ)(integrator) != 2.0
setp(integrator, lorenz2.σ)(integrator, 2.0)
@test getp(integrator, lorenz2.σ)(integrator) == 2.0
@test_throws Any getp(sol, )(sol)

@test_throws Any integrator[x]=2.0
integrator[lorenz1.x] = 2.0
Expand Down
22 changes: 11 additions & 11 deletions test/downstream/symbol_indexing.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using ModelingToolkit, OrdinaryDiffEq, RecursiveArrayTools, Test
using ModelingToolkit, OrdinaryDiffEq, RecursiveArrayTools, SymbolicIndexingInterface, Test
using Optimization, OptimizationOptimJL

@parameters t σ ρ β
Expand Down Expand Up @@ -39,10 +39,10 @@ tspan = (0.0, 100.0)
prob = ODEProblem(sys_simplified, u0, tspan, p)
sol = solve(prob, Rodas4())

@test_throws Any sol.ps[b]
@test_throws Any sol.ps[b, 1]
@test_throws Any sol.ps[b, 1:5]
@test_throws Any sol.ps[b, [1, 2, 3]]
@test_throws Any sol[b]
@test_throws Any sol[b, 1]
@test_throws Any sol[b, 1:5]
@test_throws Any sol[b, [1, 2, 3]]
@test_throws Any sol['a']
@test_throws Any sol['a', 1]
@test_throws Any sol['a', 1:5]
Expand All @@ -67,9 +67,9 @@ sol = solve(prob, Rodas4())
@test sol[α] isa Vector
@test sol[α, 3] isa Float64
@test length(sol[α, 5:10]) == 6
@test sol.ps[γ] isa Real
@test sol.ps[γ] == 2.0
@test sol.ps[(lorenz1.σ, lorenz1.ρ)] isa Tuple
@test getp(prob, γ)(sol) isa Real
@test getp(prob, γ)(sol) == 2.0
@test getp(prob, (lorenz1.σ, lorenz1.ρ))(sol) isa Tuple

@test sol[[lorenz1.x, lorenz2.x]] isa Vector{Vector{Float64}}
@test length(sol[[lorenz1.x, lorenz2.x]]) == length(sol)
Expand Down Expand Up @@ -201,7 +201,7 @@ D = Differential(t)
@named fol = ODESystem([D(x) ~ (1 - x) / tau])
prob = ODEProblem(fol, [x => 0.0], (0.0, 10.0), [tau => 3.0])
sol = solve(prob, Tsit5())
@test sol.ps[tau] == 3
@test getp(fol, tau)(sol) == 3

@testset "OptimizationSolution" begin
@variables begin
Expand All @@ -219,6 +219,6 @@ sol = solve(prob, Tsit5())
sol = solve(prob, GradientDescent())
@test sol[x]1 atol=1e-3
@test sol[y]1 atol=1e-3
@test sol.ps[a] 1
@test sol.ps[b] 100
@test getp(sys, a)(sol) 1
@test getp(sys, b)(sol) 100
end

0 comments on commit d1db9b7

Please sign in to comment.