Skip to content

Commit

Permalink
feat: add support for ParameterIndexingProxy
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Jan 5, 2024
1 parent d5edf1f commit 17e3481
Show file tree
Hide file tree
Showing 12 changed files with 92 additions and 17 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ SciMLOperators = "0.3.7"
StaticArrays = "1.7"
StaticArraysCore = "1.4"
Statistics = "1.9"
SymbolicIndexingInterface = "0.3.2"
SymbolicIndexingInterface = "0.3.3"
Tables = "1.11"
TruncatedStacktraces = "1.4"
Zygote = "0.6.67"
Expand Down
2 changes: 2 additions & 0 deletions src/integrator_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,8 @@ function Base.getproperty(A::DEIntegrator, sym::Symbol)
if sym === :destats && hasfield(typeof(A), :stats)
@warn "destats has been deprecated for stats"
getfield(A, :stats)
elseif sym === :ps
return ParameterIndexingProxy(A)

Check warning on line 455 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L455

Added line #L455 was not covered by tests
else
return getfield(A, sym)
end
Expand Down
2 changes: 2 additions & 0 deletions src/problems/basic_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,8 @@ function Base.getproperty(prob::IntegralProblem, name::Symbol)
domain = getfield(prob, :domain)
lb, ub = domain
return ub
elseif name === :ps
return ParameterIndexingProxy(prob)

Check warning on line 506 in src/problems/basic_problems.jl

View check run for this annotation

Codecov / codecov/patch

src/problems/basic_problems.jl#L505-L506

Added lines #L505 - L506 were not covered by tests
end
return Base.getfield(prob, name)
end
Expand Down
7 changes: 7 additions & 0 deletions src/problems/problem_interface.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
Base.@propagate_inbounds function Base.getproperty(prob::AbstractSciMLProblem, sym::Symbol)
if sym === :ps
return ParameterIndexingProxy(prob)

Check warning on line 3 in src/problems/problem_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/problems/problem_interface.jl#L3

Added line #L3 was not covered by tests
end
return getfield(prob, sym)
end

SymbolicIndexingInterface.symbolic_container(prob::AbstractSciMLProblem) = prob.f
function SymbolicIndexingInterface.is_observed(A::AbstractSciMLProblem, sym)
return !is_variable(A, sym) && !is_parameter(A, sym) && !is_independent_variable(A, sym) && symbolic_type(sym) == ScalarSymbolic()
Expand Down
2 changes: 2 additions & 0 deletions src/solutions/dae_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ Base.@propagate_inbounds function Base.getproperty(x::AbstractDAESolution, s::Sy
if s === :destats
Base.depwarn("`sol.destats` is deprecated. Use `sol.stats` instead.", "sol.destats")
return getfield(x, :stats)
elseif s === :ps
return ParameterIndexingProxy(x)

Check warning on line 51 in src/solutions/dae_solutions.jl

View check run for this annotation

Codecov / codecov/patch

src/solutions/dae_solutions.jl#L51

Added line #L51 was not covered by tests
end
return getfield(x, s)
end
Expand Down
2 changes: 2 additions & 0 deletions src/solutions/ode_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ Base.@propagate_inbounds function Base.getproperty(x::AbstractODESolution, s::Sy
if s === :destats
Base.depwarn("`sol.destats` is deprecated. Use `sol.stats` instead.", "sol.destats")
return getfield(x, :stats)
elseif s === :ps
return ParameterIndexingProxy(x)

Check warning on line 125 in src/solutions/ode_solutions.jl

View check run for this annotation

Codecov / codecov/patch

src/solutions/ode_solutions.jl#L125

Added line #L125 was not covered by tests
end
return getfield(x, s)
end
Expand Down
2 changes: 2 additions & 0 deletions src/solutions/optimization_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ Base.@propagate_inbounds function Base.getproperty(x::AbstractOptimizationSoluti
Base.depwarn("`sol.prob` is deprecated. Use getters like `get_p` or `get_syms` on `sol` instead.",
"sol.prob")
return getfield(x, :cache)
elseif s === :ps
return ParameterIndexingProxy(x)

Check warning on line 119 in src/solutions/optimization_solutions.jl

View check run for this annotation

Codecov / codecov/patch

src/solutions/optimization_solutions.jl#L118-L119

Added lines #L118 - L119 were not covered by tests
end
return getfield(x, s)
end
Expand Down
2 changes: 2 additions & 0 deletions src/solutions/rode_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ Base.@propagate_inbounds function Base.getproperty(x::AbstractRODESolution, s::S
if s === :destats
Base.depwarn("`sol.destats` is deprecated. Use `sol.stats` instead.", "sol.destats")
return getfield(x, :stats)
elseif s === :ps
return ParameterIndexingProxy(x)

Check warning on line 59 in src/solutions/rode_solutions.jl

View check run for this annotation

Codecov / codecov/patch

src/solutions/rode_solutions.jl#L59

Added line #L59 was not covered by tests
end
return getfield(x, s)
end
Expand Down
9 changes: 9 additions & 0 deletions src/solutions/solution_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,15 @@ end

# SymbolicIndexingInterface.jl
const AbstractSolution = Union{AbstractTimeseriesSolution,AbstractNoTimeSolution}

Base.@propagate_inbounds function Base.getproperty(A::AbstractSolution, sym::Symbol)
if sym === :ps
return ParameterIndexingProxy(A)

Check warning on line 35 in src/solutions/solution_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/solutions/solution_interface.jl#L35

Added line #L35 was not covered by tests
else
return getfield(A, sym)
end
end

SymbolicIndexingInterface.symbolic_container(A::AbstractSolution) = A.prob.f
SymbolicIndexingInterface.parameter_values(A::AbstractSolution) = A.prob.p

Expand Down
14 changes: 14 additions & 0 deletions test/downstream/integrator_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,13 @@ integrator = init(oprob, Rodas4())
@test_throws Exception integrator[population_model.a]
@test_throws Exception integrator[:a]
@test getp(oprob, a)(integrator) == getp(oprob, population_model.a)(integrator) == getp(oprob, :a)(integrator) == 2.0
@test integrator.ps[a] == integrator.ps[population_model.a] == integrator.ps[:a] == 2.0
@test getp(oprob, b)(integrator) == getp(oprob, population_model.b)(integrator) == getp(oprob, :b)(integrator) == 1.0
@test integrator.ps[b] == integrator.ps[population_model.b] == integrator.ps[:b] == 1.0
@test getp(oprob, c)(integrator) == getp(oprob, population_model.c)(integrator) == getp(oprob, :c)(integrator) == 1.0
@test integrator.ps[d] == integrator.ps[population_model.d] == integrator.ps[:d] == 1.0
@test getp(oprob, d)(integrator) == getp(oprob, population_model.d)(integrator) == getp(oprob, :d)(integrator) == 1.0
@test integrator.ps[d] == integrator.ps[population_model.d] == integrator.ps[:d] == 1.0

@test integrator[s1] == integrator[population_model.s1] == integrator[:s1] == 2.0
@test integrator[s2] == integrator[population_model.s2] == integrator[:s2] == 1.0
Expand All @@ -42,10 +46,15 @@ step!(integrator, 100.0, true)

setp(oprob, a)(integrator, 10.0)
@test getp(integrator, a)(integrator) == getp(integrator, population_model.a)(integrator) == getp(integrator, :a)(integrator) == 10.0
@test integrator.ps[a] == integrator.ps[population_model.a] == integrator.ps[:a] == 10.0
setp(population_model, population_model.b)(integrator, 20.0)
@test getp(integrator, b)(integrator) == getp(integrator, population_model.b)(integrator) == getp(integrator, :b)(integrator) == 20.0
@test integrator.ps[b] == integrator.ps[population_model.b] == integrator.ps[:b] == 20.0
setp(integrator, c)(integrator, 30.0)
@test getp(integrator, c)(integrator) == getp(integrator, population_model.c)(integrator) == getp(integrator, :c)(integrator) == 30.0
@test integrator.ps[c] == integrator.ps[population_model.c] == integrator.ps[:c] == 30.0
integrator.ps[d] = 40.0
@test integrator.ps[d] == integrator.ps[population_model.d] == integrator.ps[:d] == 40.0

integrator[s1] = 10.0
@test integrator[s1] == integrator[population_model.s1] == integrator[:s1] == 10.0
Expand Down Expand Up @@ -330,3 +339,8 @@ set_arr!(integrator, [[1.0, 2.0, 3.0], 1.0])
@test get_arr(integrator) == [[1.0, 2.0, 3.0], 1.0]
set_tuple!(integrator, ([2.0, 4.0, 6.0], 2.0))
@test get_tuple(integrator) == ([2.0, 4.0, 6.0], 2.0)
@test getp(sys, p)(integrator) == integrator.ps[p] == [1, 2, 3]
setp(sys, p)(integrator, [4, 5, 6])
@test getp(sys, p)(integrator) == integrator.ps[p] == [4, 5, 6]
integrator.ps[p] = [7, 8, 9]
@test getp(sys, p)(integrator) == integrator.ps[p] == [7, 8, 9]
54 changes: 41 additions & 13 deletions test/downstream/problem_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,15 @@ oprob = ODEProblem(sys, u0, tspan, p, jac = true)
getσ1 = getp(sys, σ)
getσ2 = getp(sys, sys.σ)
getσ3 = getp(sys, )
@test getσ1(oprob) == getσ2(oprob) == getσ3(oprob) == 28.0
@test getσ1(oprob) == getσ2(oprob) == getσ3(oprob) == oprob.ps[σ] == oprob.ps[sys.σ] == oprob.ps[] == 28.0
getρ1 = getp(sys, ρ)
getρ2 = getp(sys, sys.ρ)
getρ3 = getp(sys, )
@test getρ1(oprob) == getρ2(oprob) == getρ3(oprob) == 10.0
@test getρ1(oprob) == getρ2(oprob) == getρ3(oprob) == oprob.ps[ρ] == oprob.ps[sys.ρ] == oprob.ps[] == 10.0
getβ1 = getp(sys, β)
getβ2 = getp(sys, sys.β)
getβ3 = getp(sys, )
@test getβ1(oprob) == getβ2(oprob) == getβ3(oprob) == 8 / 3
@test getβ1(oprob) == getβ2(oprob) == getβ3(oprob) == oprob.ps[β] == oprob.ps[sys.β] == oprob.ps[] == 8 / 3

@test oprob[x] == oprob[sys.x] == oprob[:x] == 1.0
@test oprob[y] == oprob[sys.y] == oprob[:y] == 0.0
Expand All @@ -51,13 +51,20 @@ getβ3 = getp(sys, :β)

setσ = setp(sys, σ)
setσ(oprob, 10.0)
@test getσ1(oprob) == getσ2(oprob) == getσ3(oprob) == 10.0
@test getσ1(oprob) == getσ2(oprob) == getσ3(oprob) == oprob.ps[σ] == oprob.ps[sys.σ] == oprob.ps[] == 10.0
setρ = setp(sys, sys.ρ)
setρ(oprob, 20.0)
@test getρ1(oprob) == getρ2(oprob) == getρ3(oprob) == 20.0
@test getρ1(oprob) == getρ2(oprob) == getρ3(oprob) == oprob.ps[ρ] == oprob.ps[sys.ρ] == oprob.ps[] == 20.0
setβ = setp(sys, )
setβ(oprob, 30.0)
@test getβ1(oprob) == getβ2(oprob) == getβ3(oprob) == 30.0
@test getβ1(oprob) == getβ2(oprob) == getβ3(oprob) == oprob.ps[β] == oprob.ps[sys.β] == oprob.ps[] == 30.0

oprob.ps[σ] = 11.0
@test getσ1(oprob) == getσ2(oprob) == getσ3(oprob) == oprob.ps[σ] == oprob.ps[sys.σ] == oprob.ps[] == 11.0
oprob.ps[sys.ρ] = 21.0
@test getρ1(oprob) == getρ2(oprob) == getρ3(oprob) == oprob.ps[ρ] == oprob.ps[sys.ρ] == oprob.ps[] == 21.0
oprob.ps[] = 31.0
@test getβ1(oprob) == getβ2(oprob) == getβ3(oprob) == oprob.ps[β] == oprob.ps[sys.β] == oprob.ps[] == 31.0

oprob[x] = 10.0
@test oprob[x] == oprob[sys.x] == oprob[:x] == 10.0
Expand Down Expand Up @@ -103,22 +110,28 @@ noiseeqs = [0.1 * x,
sprob = SDEProblem(noise_sys, u0, (0.0, 100.0), p)
u0

@test getσ1(sprob) == getσ2(sprob) == getσ3(sprob) == 28.0
@test getρ1(sprob) == getρ2(sprob) == getρ3(sprob) == 10.0
@test getβ1(sprob) == getβ2(sprob) == getβ3(sprob) == 8 / 3
@test getσ1(sprob) == getσ2(sprob) == getσ3(sprob) == sprob.ps[σ] == sprob.ps[sys.σ] == sprob.ps[] == 28.0
@test getρ1(sprob) == getρ2(sprob) == getρ3(sprob) == sprob.ps[ρ] == sprob.ps[sys.ρ] == sprob.ps[] == 10.0
@test getβ1(sprob) == getβ2(sprob) == getβ3(sprob) == sprob.ps[β] == sprob.ps[sys.β] == sprob.ps[] == 8 / 3

@test sprob[x] == sprob[noise_sys.x] == sprob[:x] == 1.0
@test sprob[y] == sprob[noise_sys.y] == sprob[:y] == 0.0
@test sprob[z] == sprob[noise_sys.z] == sprob[:z] == 0.0

setσ(sprob, 10.0)
@test getσ1(sprob) == getσ2(sprob) == getσ3(sprob) == 10.0
@test getσ1(sprob) == getσ2(sprob) == getσ3(sprob) == sprob.ps[σ] == sprob.ps[sys.σ] == sprob.ps[] == 10.0
setρ(sprob, 20.0)
@test getρ1(sprob) == getρ2(sprob) == getρ3(sprob) == 20.0
@test getρ1(sprob) == getρ2(sprob) == getρ3(sprob) == sprob.ps[ρ] == sprob.ps[sys.ρ] == sprob.ps[] == 20.0
setp(noise_sys, noise_sys.ρ)(sprob, 25.0)
@test getρ1(sprob) == getρ2(sprob) == getρ3(sprob) == 25.0
@test getρ1(sprob) == getρ2(sprob) == getρ3(sprob) == sprob.ps[ρ] == sprob.ps[sys.ρ] == sprob.ps[] == 25.0
setβ(sprob, 30.0)
@test getβ1(sprob) == getβ2(sprob) == getβ3(sprob) == 30.0
@test getβ1(sprob) == getβ2(sprob) == getβ3(sprob) == sprob.ps[β] == sprob.ps[sys.β] == sprob.ps[] == 30.0
sprob.ps[σ] = 11.0
@test getσ1(sprob) == getσ2(sprob) == getσ3(sprob) == sprob.ps[σ] == sprob.ps[sys.σ] == sprob.ps[] == 11.0
sprob.ps[sys.ρ] = 21.0
@test getρ1(sprob) == getρ2(sprob) == getρ3(sprob) == sprob.ps[ρ] == sprob.ps[sys.ρ] == sprob.ps[] == 21.0
sprob.ps[] = 31.0
@test getβ1(sprob) == getβ2(sprob) == getβ3(sprob) == sprob.ps[β] == sprob.ps[sys.β] == sprob.ps[] == 31.0

sprob[x] = 10.0
@test sprob[x] == sprob[noise_sys.x] == sprob[:x] == 10.0
Expand Down Expand Up @@ -155,3 +168,18 @@ set_tuple!(sprob, 13.0)
@test get_tuple(sprob) == (13.0, 13.0)
set_tuple!(sprob, [10.0, 10.0])
@test get_tuple(sprob) == (10.0, 10.0)

using LinearAlgebra
@variables t
sts = @variables x(t)[1:3]=[1, 2, 3.0] y(t)=1.0
ps = @parameters p[1:3] = [1, 2, 3]
D = Differential(t)
eqs = [collect(D.(x) .~ x)
D(y) ~ norm(x) * y - x[1]]
@named sys = ODESystem(eqs, t, [sts...;], [ps...;])
prob = ODEProblem(sys, [], (0, 1.0))
@test getp(sys, p)(prob) == prob.ps[p] == [1, 2, 3]
setp(sys, p)(prob, [4, 5, 6])
@test getp(sys, p)(prob) == prob.ps[p] == [4, 5, 6]
prob.ps[p] = [7, 8, 9]
@test getp(sys, p)(prob) == prob.ps[p] == [7, 8, 9]
11 changes: 8 additions & 3 deletions test/downstream/symbol_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,10 @@ sol = solve(prob, Rodas4())
@test sol[α, 3] isa Float64
@test length(sol[α, 5:10]) == 6
@test getp(prob, γ)(sol) isa Real
@test getp(prob, γ)(sol) == getp(prob, )(sol) == 2.0
@test sol.ps[γ] isa Real
@test getp(prob, γ)(sol) == getp(prob, )(sol) == sol.ps[γ] == sol.ps[] == 2.0
@test getp(prob, (lorenz1.σ, lorenz1.ρ))(sol) isa Tuple
@test sol.ps[(lorenz1.σ, lorenz1.ρ)] isa Tuple

@test sol[[lorenz1.x, lorenz2.x]] isa Vector{Vector{Float64}}
@test length(sol[[lorenz1.x, lorenz2.x]]) == length(sol)
Expand Down Expand Up @@ -179,8 +181,8 @@ sol10 = sol(0.1, idxs = 2)
getx = getu(sys_simplified, lorenz1.x)
get_arr = getu(sys_simplified, [lorenz1.x, lorenz2.x])
get_tuple = getu(sys_simplified, (lorenz1.x, lorenz2.x))
get_obs = getu(sys_simplified, lorenz1.x + lorenz2.x)
get_obs_arr = getu(sys_simplified, [lorenz1.x + lorenz2.x, lorenz1.y + lorenz2.y])
get_obs = getu(sol, lorenz1.x + lorenz2.x) # can't use sys for observed
get_obs_arr = getu(sol, [lorenz1.x + lorenz2.x, lorenz1.y + lorenz2.y])
l1x_idx = variable_index(sol, lorenz1.x)
l2x_idx = variable_index(sol, lorenz2.x)
l1y_idx = variable_index(sol, lorenz1.y)
Expand Down Expand Up @@ -213,6 +215,7 @@ prob = ODEProblem(sys, [], (0, 1.0))
sol = solve(prob, Tsit5())
@test sol[x] isa Vector{<:Vector}
@test sol[@nonamespace sys.x] isa Vector{<:Vector}
@test sol.ps[p] == [1, 2, 3]

getx = getu(sys, x)
get_mix_arr = getu(sys, [x, y])
Expand Down Expand Up @@ -251,4 +254,6 @@ sol = solve(prob, Tsit5())
@test sol[y]1 atol=1e-3
@test getp(sys, a)(sol) 1
@test getp(sys, b)(sol) 100
@test sol.ps[a] 1
@test sol.ps[b] 100
end

0 comments on commit 17e3481

Please sign in to comment.