From c664868b4c86d748776e1a2a9a094a4a20d7c919 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Mon, 29 Apr 2024 11:48:37 -0400 Subject: [PATCH 1/4] Batch observed function eval if possible --- src/vector_of_array.jl | 11 +++++++++-- test/downstream/symbol_indexing.jl | 2 +- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/vector_of_array.jl b/src/vector_of_array.jl index 40a48233..4f8a368a 100644 --- a/src/vector_of_array.jl +++ b/src/vector_of_array.jl @@ -406,13 +406,20 @@ Base.@propagate_inbounds function _getindex( if all(x -> is_parameter(A, x), sym) error("Indexing with parameters is deprecated. Use `getp(A, $sym)` for parameter indexing.") else - return [getindex.((A,), sym, i) for i in eachindex(A.t)] + return A[sym, eachindex(A.t)] end end Base.@propagate_inbounds function _getindex( A::AbstractDiffEqArray, ::ScalarSymbolic, sym::Union{Tuple, AbstractArray}, args...) - return reduce(vcat, map(s -> A[s, args...]', sym)) + u = A.u[args...] + t = A.t[args...] + observed_fn = observed(A, sym) + if t isa AbstractArray + return observed_fn.(u, (parameter_values(A),), t) + else + return observed_fn(u, parameter_values(A), t) + end end Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ScalarSymbolic, diff --git a/test/downstream/symbol_indexing.jl b/test/downstream/symbol_indexing.jl index e384f714..bcb8dced 100644 --- a/test/downstream/symbol_indexing.jl +++ b/test/downstream/symbol_indexing.jl @@ -35,7 +35,7 @@ sol_new = DiffEqArray(sol.u[1:10], @test_throws Exception sol_new[τ] gs, = Zygote.gradient(sol) do sol - sum(sol[fol_separate.x]) + sum(sol[fol_separate.x]) end @test "Symbolic Indexing ADjoint" all(all.(isone, gs.u)) From ea53e1291e1f0d64bb9b66d53be0545f3811bf5f Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 30 Apr 2024 12:01:27 +0530 Subject: [PATCH 2/4] refactor: use getu for symbolic indexing, remove implicit LabelledArrays dependency --- Project.toml | 6 +- src/vector_of_array.jl | 80 +++++------------------- test/symbolic_indexing_interface_test.jl | 9 +-- 3 files changed, 21 insertions(+), 74 deletions(-) diff --git a/Project.toml b/Project.toml index b219a05d..7d3016ff 100644 --- a/Project.toml +++ b/Project.toml @@ -44,7 +44,6 @@ FastBroadcast = "0.2.8" ForwardDiff = "0.10.19" GPUArraysCore = "0.1.1" IteratorInterfaceExtensions = "1" -LabelledArrays = "1.15" LinearAlgebra = "1.10" Measurements = "2.3" MonteCarloMeasurements = "1.1" @@ -60,7 +59,7 @@ StaticArrays = "1.6" StaticArraysCore = "1.4" Statistics = "1.10" StructArrays = "0.6.11" -SymbolicIndexingInterface = "0.3.2" +SymbolicIndexingInterface = "0.3.19" Tables = "1.11" Test = "1" Tracker = "0.2.15" @@ -72,7 +71,6 @@ julia = "1.10" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800" Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7" MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca" NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56" @@ -88,4 +86,4 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["SafeTestsets", "Aqua", "FastBroadcast", "ForwardDiff", "LabelledArrays", "NLsolve", "OrdinaryDiffEq", "Pkg", "Test", "Unitful", "Random", "StaticArrays", "StructArrays", "Zygote", "Measurements"] +test = ["SafeTestsets", "Aqua", "FastBroadcast", "ForwardDiff", "NLsolve", "OrdinaryDiffEq", "Pkg", "Test", "Unitful", "Random", "StaticArrays", "StructArrays", "Zygote", "Measurements"] diff --git a/src/vector_of_array.jl b/src/vector_of_array.jl index 4f8a368a..764a42e6 100644 --- a/src/vector_of_array.jl +++ b/src/vector_of_array.jl @@ -352,73 +352,27 @@ Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::NotSymboli end # Symbolic Indexing Methods -Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ScalarSymbolic, sym) - if is_independent_variable(A, sym) - return A.t - elseif is_variable(A, sym) - if constant_structure(A) - return getindex.(A.u, variable_index(A, sym)) - else - return getindex.(A.u, variable_index.((A,), (sym,), eachindex(A.t))) +for symtype in [ScalarSymbolic, ArraySymbolic] + paramcheck = quote + if is_parameter(A, sym) || (sym isa AbstractArray && symbolic_type(eltype(sym)) !== NotSymbolic() || sym isa Tuple) && all(x -> is_parameter(A, x), sym) + error("Indexing with parameters is deprecated. Use `getp(A, $sym)` for parameter indexing.") end - elseif is_parameter(A, sym) - error("Indexing with parameters is deprecated. Use `getp(A, $sym)` for parameter indexing.") - elseif is_observed(A, sym) - return observed(A, sym).(A.u, (parameter_values(A),), A.t) - else - # NOTE: this is basically just for LabelledArrays. It's better if this - # were an error. Should we make an extension for LabelledArrays handling - # this case? - return getindex.(A.u, sym) end -end - -Base.@propagate_inbounds function _getindex( - A::AbstractDiffEqArray, ::ScalarSymbolic, sym, args...) - if is_independent_variable(A, sym) - return A.t[args...] - elseif is_variable(A, sym) - return A[sym][args...] - elseif is_observed(A, sym) - u = A.u[args...] - t = A.t[args...] - observed_fn = observed(A, sym) - if t isa AbstractArray - return observed_fn.(u, (parameter_values(A),), t) - else - return observed_fn(u, parameter_values(A), t) - end - else - # NOTE: this is basically just for LabelledArrays. It's better if this - # were an error. Should we make an extension for LabelledArrays handling - # this case? - return getindex.(A.u[args...], sym) + @eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype, sym) + $paramcheck + getu(A, sym)(A) end -end - -Base.@propagate_inbounds function _getindex( - A::AbstractDiffEqArray, ::ArraySymbolic, sym, args...) - return getindex(A, collect(sym), args...) -end - -Base.@propagate_inbounds function _getindex( - A::AbstractDiffEqArray, ::ScalarSymbolic, sym::Union{Tuple, AbstractArray}) - if all(x -> is_parameter(A, x), sym) - error("Indexing with parameters is deprecated. Use `getp(A, $sym)` for parameter indexing.") - else - return A[sym, eachindex(A.t)] + @eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype, sym, arg) + $paramcheck + getu(A, sym)(A, arg) end -end - -Base.@propagate_inbounds function _getindex( - A::AbstractDiffEqArray, ::ScalarSymbolic, sym::Union{Tuple, AbstractArray}, args...) - u = A.u[args...] - t = A.t[args...] - observed_fn = observed(A, sym) - if t isa AbstractArray - return observed_fn.(u, (parameter_values(A),), t) - else - return observed_fn(u, parameter_values(A), t) + @eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype, sym, arg::Union{AbstractArray{Int}, AbstractArray{Bool}}) + $paramcheck + getu(A, sym).((A,), arg) + end + @eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype, sym, arg::Colon) + $paramcheck + getu(A, sym)(A) end end diff --git a/test/symbolic_indexing_interface_test.jl b/test/symbolic_indexing_interface_test.jl index 74ab04a6..f7105005 100644 --- a/test/symbolic_indexing_interface_test.jl +++ b/test/symbolic_indexing_interface_test.jl @@ -1,4 +1,4 @@ -using RecursiveArrayTools, Test, LabelledArrays, SymbolicIndexingInterface +using RecursiveArrayTools, Test, SymbolicIndexingInterface t = 0.0:0.1:1.0 f(x) = 2x @@ -20,7 +20,7 @@ dx = DiffEqArray([[f(x), f2(x)] for x in t], @test dx[[:a, :b]] ≈ [[f(x), f2(x)] for x in t] @test dx[(:a, :b)] == [(f(x), f2(x)) for x in t] @test dx[[:a, :b], 3] ≈ [f(t[3]), f2(t[3])] -@test dx[[:a, :b], 4:5] ≈ vcat(f.(t[4:5])', f2.(t[4:5])') +@test dx[[:a, :b], 4:5] ≈ vcat.(f.(t[4:5]), f2.(t[4:5])) @test dx[solvedvariables] == dx[allvariables] == dx[[:a, :b]] @test dx[solvedvariables, 3] == dx[allvariables, 3] == dx[[:a, :b], 3] @test getp(dx, [:p, :q])(dx) == [1.0, 2.0] @@ -53,8 +53,3 @@ get_tuple = getu(dx, (:a, :b)) dx = DiffEqArray([[f(x), f2(x)] for x in t], t; variables = [:a, :b]) @test_throws Exception dx[nothing] # make sure it isn't storing [nothing] as indepsym - -ABC = @SLVector (:a, :b, :c); -A = ABC(1, 2, 3); -B = RecursiveArrayTools.DiffEqArray([A, A], [0.0, 2.0]); -@test getindex(B, :a) == [1, 1] From bb95814a55d65effea166967a186da42439623ad Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 30 Apr 2024 12:02:13 +0530 Subject: [PATCH 3/4] ci: add LabelledArrays RAT testset to downstream --- .github/workflows/Downstream.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/Downstream.yml b/.github/workflows/Downstream.yml index 29e23edb..3bbfea5d 100644 --- a/.github/workflows/Downstream.yml +++ b/.github/workflows/Downstream.yml @@ -31,6 +31,7 @@ jobs: - {user: SciML, repo: SciMLSensitivity.jl, group: Core4} - {user: SciML, repo: SciMLSensitivity.jl, group: Core5} - {user: SciML, repo: SciMLSensitivity.jl, group: Core6} + - {user: SciML, repo: LabelledArrays.jl, group: RecursiveArrayTools} steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 From d3c2ecd87edb9b38490cb8fa98d551984dd741aa Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 30 Apr 2024 13:31:28 +0530 Subject: [PATCH 4/4] test: fix symbol_indexing testset --- test/downstream/symbol_indexing.jl | 5 ++++- test/runtests.jl | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/test/downstream/symbol_indexing.jl b/test/downstream/symbol_indexing.jl index bcb8dced..5ef7e3ac 100644 --- a/test/downstream/symbol_indexing.jl +++ b/test/downstream/symbol_indexing.jl @@ -1,4 +1,5 @@ using RecursiveArrayTools, ModelingToolkit, OrdinaryDiffEq, SymbolicIndexingInterface, Test +using Zygote using ModelingToolkit: t_nounits as t, D_nounits as D include("../testutils.jl") @@ -38,7 +39,9 @@ gs, = Zygote.gradient(sol) do sol sum(sol[fol_separate.x]) end -@test "Symbolic Indexing ADjoint" all(all.(isone, gs.u)) +@testset "Symbolic Indexing ADjoint" begin + @test all(all.(isone, gs.u)) +end # Tables interface test_tables_interface(sol_new, [:timestamp, Symbol("x(t)")], hcat(sol_new[t], sol_new[x])) diff --git a/test/runtests.jl b/test/runtests.jl index bf264273..819e40f3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -47,8 +47,8 @@ end if GROUP == "SymbolicIndexingInterface" || GROUP == "Downstream" if GROUP == "SymbolicIndexingInterface" activate_downstream_env() - @time @safetestset "DiffEqArray Indexing Tests" include("downstream/symbol_indexing.jl") end + @time @safetestset "DiffEqArray Indexing Tests" include("downstream/symbol_indexing.jl") end if GROUP == "GPU"