Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: use getu for symbolic indexing #369

Merged
merged 4 commits into from
Apr 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/Downstream.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ jobs:
- {user: SciML, repo: SciMLSensitivity.jl, group: Core4}
- {user: SciML, repo: SciMLSensitivity.jl, group: Core5}
- {user: SciML, repo: SciMLSensitivity.jl, group: Core6}
- {user: SciML, repo: LabelledArrays.jl, group: RecursiveArrayTools}
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
Expand Down
6 changes: 2 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ FastBroadcast = "0.2.8"
ForwardDiff = "0.10.19"
GPUArraysCore = "0.1.1"
IteratorInterfaceExtensions = "1"
LabelledArrays = "1.15"
LinearAlgebra = "1.10"
Measurements = "2.3"
MonteCarloMeasurements = "1.1"
Expand All @@ -60,7 +59,7 @@ StaticArrays = "1.6"
StaticArraysCore = "1.4"
Statistics = "1.10"
StructArrays = "0.6.11"
SymbolicIndexingInterface = "0.3.2"
SymbolicIndexingInterface = "0.3.19"
Tables = "1.11"
Test = "1"
Tracker = "0.2.15"
Expand All @@ -72,7 +71,6 @@ julia = "1.10"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
Expand All @@ -88,4 +86,4 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["SafeTestsets", "Aqua", "FastBroadcast", "ForwardDiff", "LabelledArrays", "NLsolve", "OrdinaryDiffEq", "Pkg", "Test", "Unitful", "Random", "StaticArrays", "StructArrays", "Zygote", "Measurements"]
test = ["SafeTestsets", "Aqua", "FastBroadcast", "ForwardDiff", "NLsolve", "OrdinaryDiffEq", "Pkg", "Test", "Unitful", "Random", "StaticArrays", "StructArrays", "Zygote", "Measurements"]
75 changes: 18 additions & 57 deletions src/vector_of_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -352,67 +352,28 @@
end

# Symbolic Indexing Methods
Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ScalarSymbolic, sym)
if is_independent_variable(A, sym)
return A.t
elseif is_variable(A, sym)
if constant_structure(A)
return getindex.(A.u, variable_index(A, sym))
else
return getindex.(A.u, variable_index.((A,), (sym,), eachindex(A.t)))
for symtype in [ScalarSymbolic, ArraySymbolic]
paramcheck = quote
if is_parameter(A, sym) || (sym isa AbstractArray && symbolic_type(eltype(sym)) !== NotSymbolic() || sym isa Tuple) && all(x -> is_parameter(A, x), sym)

Check warning on line 357 in src/vector_of_array.jl

View check run for this annotation

Codecov / codecov/patch

src/vector_of_array.jl#L357

Added line #L357 was not covered by tests
error("Indexing with parameters is deprecated. Use `getp(A, $sym)` for parameter indexing.")
end
elseif is_parameter(A, sym)
error("Indexing with parameters is deprecated. Use `getp(A, $sym)` for parameter indexing.")
elseif is_observed(A, sym)
return observed(A, sym).(A.u, (parameter_values(A),), A.t)
else
# NOTE: this is basically just for LabelledArrays. It's better if this
# were an error. Should we make an extension for LabelledArrays handling
# this case?
return getindex.(A.u, sym)
end
end

Base.@propagate_inbounds function _getindex(
A::AbstractDiffEqArray, ::ScalarSymbolic, sym, args...)
if is_independent_variable(A, sym)
return A.t[args...]
elseif is_variable(A, sym)
return A[sym][args...]
elseif is_observed(A, sym)
u = A.u[args...]
t = A.t[args...]
observed_fn = observed(A, sym)
if t isa AbstractArray
return observed_fn.(u, (parameter_values(A),), t)
else
return observed_fn(u, parameter_values(A), t)
end
else
# NOTE: this is basically just for LabelledArrays. It's better if this
# were an error. Should we make an extension for LabelledArrays handling
# this case?
return getindex.(A.u[args...], sym)
@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype, sym)
$paramcheck
getu(A, sym)(A)

Check warning on line 363 in src/vector_of_array.jl

View check run for this annotation

Codecov / codecov/patch

src/vector_of_array.jl#L361-L363

Added lines #L361 - L363 were not covered by tests
end
end

Base.@propagate_inbounds function _getindex(
A::AbstractDiffEqArray, ::ArraySymbolic, sym, args...)
return getindex(A, collect(sym), args...)
end

Base.@propagate_inbounds function _getindex(
A::AbstractDiffEqArray, ::ScalarSymbolic, sym::Union{Tuple, AbstractArray})
if all(x -> is_parameter(A, x), sym)
error("Indexing with parameters is deprecated. Use `getp(A, $sym)` for parameter indexing.")
else
return [getindex.((A,), sym, i) for i in eachindex(A.t)]
@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype, sym, arg)
$paramcheck
getu(A, sym)(A, arg)

Check warning on line 367 in src/vector_of_array.jl

View check run for this annotation

Codecov / codecov/patch

src/vector_of_array.jl#L365-L367

Added lines #L365 - L367 were not covered by tests
end
@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype, sym, arg::Union{AbstractArray{Int}, AbstractArray{Bool}})
$paramcheck
getu(A, sym).((A,), arg)

Check warning on line 371 in src/vector_of_array.jl

View check run for this annotation

Codecov / codecov/patch

src/vector_of_array.jl#L369-L371

Added lines #L369 - L371 were not covered by tests
end
@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype, sym, arg::Colon)
$paramcheck
getu(A, sym)(A)

Check warning on line 375 in src/vector_of_array.jl

View check run for this annotation

Codecov / codecov/patch

src/vector_of_array.jl#L373-L375

Added lines #L373 - L375 were not covered by tests
end
end

Base.@propagate_inbounds function _getindex(
A::AbstractDiffEqArray, ::ScalarSymbolic, sym::Union{Tuple, AbstractArray}, args...)
return reduce(vcat, map(s -> A[s, args...]', sym))
end

Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ScalarSymbolic,
Expand Down
7 changes: 5 additions & 2 deletions test/downstream/symbol_indexing.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using RecursiveArrayTools, ModelingToolkit, OrdinaryDiffEq, SymbolicIndexingInterface, Test
using Zygote
using ModelingToolkit: t_nounits as t, D_nounits as D

include("../testutils.jl")
Expand Down Expand Up @@ -35,10 +36,12 @@ sol_new = DiffEqArray(sol.u[1:10],
@test_throws Exception sol_new[τ]

gs, = Zygote.gradient(sol) do sol
sum(sol[fol_separate.x])
sum(sol[fol_separate.x])
end

@test "Symbolic Indexing ADjoint" all(all.(isone, gs.u))
@testset "Symbolic Indexing ADjoint" begin
@test all(all.(isone, gs.u))
end

# Tables interface
test_tables_interface(sol_new, [:timestamp, Symbol("x(t)")], hcat(sol_new[t], sol_new[x]))
Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ end
if GROUP == "SymbolicIndexingInterface" || GROUP == "Downstream"
if GROUP == "SymbolicIndexingInterface"
activate_downstream_env()
@time @safetestset "DiffEqArray Indexing Tests" include("downstream/symbol_indexing.jl")
end
@time @safetestset "DiffEqArray Indexing Tests" include("downstream/symbol_indexing.jl")
end

if GROUP == "GPU"
Expand Down
9 changes: 2 additions & 7 deletions test/symbolic_indexing_interface_test.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using RecursiveArrayTools, Test, LabelledArrays, SymbolicIndexingInterface
using RecursiveArrayTools, Test, SymbolicIndexingInterface

t = 0.0:0.1:1.0
f(x) = 2x
Expand All @@ -20,7 +20,7 @@ dx = DiffEqArray([[f(x), f2(x)] for x in t],
@test dx[[:a, :b]] ≈ [[f(x), f2(x)] for x in t]
@test dx[(:a, :b)] == [(f(x), f2(x)) for x in t]
@test dx[[:a, :b], 3] ≈ [f(t[3]), f2(t[3])]
@test dx[[:a, :b], 4:5] ≈ vcat(f.(t[4:5])', f2.(t[4:5])')
@test dx[[:a, :b], 4:5] ≈ vcat.(f.(t[4:5]), f2.(t[4:5]))
@test dx[solvedvariables] == dx[allvariables] == dx[[:a, :b]]
@test dx[solvedvariables, 3] == dx[allvariables, 3] == dx[[:a, :b], 3]
@test getp(dx, [:p, :q])(dx) == [1.0, 2.0]
Expand Down Expand Up @@ -53,8 +53,3 @@ get_tuple = getu(dx, (:a, :b))

dx = DiffEqArray([[f(x), f2(x)] for x in t], t; variables = [:a, :b])
@test_throws Exception dx[nothing] # make sure it isn't storing [nothing] as indepsym

ABC = @SLVector (:a, :b, :c);
A = ABC(1, 2, 3);
B = RecursiveArrayTools.DiffEqArray([A, A], [0.0, 2.0]);
@test getindex(B, :a) == [1, 1]
Loading