Skip to content

Commit

Permalink
Merge pull request #369 from AayushSabharwal/myb/batch
Browse files Browse the repository at this point in the history
refactor: use getu for symbolic indexing
  • Loading branch information
ChrisRackauckas authored Apr 30, 2024
2 parents b568dde + d3c2ecd commit 9508294
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 71 deletions.
1 change: 1 addition & 0 deletions .github/workflows/Downstream.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ jobs:
- {user: SciML, repo: SciMLSensitivity.jl, group: Core4}
- {user: SciML, repo: SciMLSensitivity.jl, group: Core5}
- {user: SciML, repo: SciMLSensitivity.jl, group: Core6}
- {user: SciML, repo: LabelledArrays.jl, group: RecursiveArrayTools}
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
Expand Down
6 changes: 2 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ FastBroadcast = "0.2.8"
ForwardDiff = "0.10.19"
GPUArraysCore = "0.1.1"
IteratorInterfaceExtensions = "1"
LabelledArrays = "1.15"
LinearAlgebra = "1.10"
Measurements = "2.3"
MonteCarloMeasurements = "1.1"
Expand All @@ -60,7 +59,7 @@ StaticArrays = "1.6"
StaticArraysCore = "1.4"
Statistics = "1.10"
StructArrays = "0.6.11"
SymbolicIndexingInterface = "0.3.2"
SymbolicIndexingInterface = "0.3.19"
Tables = "1.11"
Test = "1"
Tracker = "0.2.15"
Expand All @@ -72,7 +71,6 @@ julia = "1.10"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
Expand All @@ -88,4 +86,4 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["SafeTestsets", "Aqua", "FastBroadcast", "ForwardDiff", "LabelledArrays", "NLsolve", "OrdinaryDiffEq", "Pkg", "Test", "Unitful", "Random", "StaticArrays", "StructArrays", "Zygote", "Measurements"]
test = ["SafeTestsets", "Aqua", "FastBroadcast", "ForwardDiff", "NLsolve", "OrdinaryDiffEq", "Pkg", "Test", "Unitful", "Random", "StaticArrays", "StructArrays", "Zygote", "Measurements"]
75 changes: 18 additions & 57 deletions src/vector_of_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -352,67 +352,28 @@ Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::NotSymboli
end

# Symbolic Indexing Methods
Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ScalarSymbolic, sym)
if is_independent_variable(A, sym)
return A.t
elseif is_variable(A, sym)
if constant_structure(A)
return getindex.(A.u, variable_index(A, sym))
else
return getindex.(A.u, variable_index.((A,), (sym,), eachindex(A.t)))
for symtype in [ScalarSymbolic, ArraySymbolic]
paramcheck = quote
if is_parameter(A, sym) || (sym isa AbstractArray && symbolic_type(eltype(sym)) !== NotSymbolic() || sym isa Tuple) && all(x -> is_parameter(A, x), sym)
error("Indexing with parameters is deprecated. Use `getp(A, $sym)` for parameter indexing.")
end
elseif is_parameter(A, sym)
error("Indexing with parameters is deprecated. Use `getp(A, $sym)` for parameter indexing.")
elseif is_observed(A, sym)
return observed(A, sym).(A.u, (parameter_values(A),), A.t)
else
# NOTE: this is basically just for LabelledArrays. It's better if this
# were an error. Should we make an extension for LabelledArrays handling
# this case?
return getindex.(A.u, sym)
end
end

Base.@propagate_inbounds function _getindex(
A::AbstractDiffEqArray, ::ScalarSymbolic, sym, args...)
if is_independent_variable(A, sym)
return A.t[args...]
elseif is_variable(A, sym)
return A[sym][args...]
elseif is_observed(A, sym)
u = A.u[args...]
t = A.t[args...]
observed_fn = observed(A, sym)
if t isa AbstractArray
return observed_fn.(u, (parameter_values(A),), t)
else
return observed_fn(u, parameter_values(A), t)
end
else
# NOTE: this is basically just for LabelledArrays. It's better if this
# were an error. Should we make an extension for LabelledArrays handling
# this case?
return getindex.(A.u[args...], sym)
@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype, sym)
$paramcheck
getu(A, sym)(A)
end
end

Base.@propagate_inbounds function _getindex(
A::AbstractDiffEqArray, ::ArraySymbolic, sym, args...)
return getindex(A, collect(sym), args...)
end

Base.@propagate_inbounds function _getindex(
A::AbstractDiffEqArray, ::ScalarSymbolic, sym::Union{Tuple, AbstractArray})
if all(x -> is_parameter(A, x), sym)
error("Indexing with parameters is deprecated. Use `getp(A, $sym)` for parameter indexing.")
else
return [getindex.((A,), sym, i) for i in eachindex(A.t)]
@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype, sym, arg)
$paramcheck
getu(A, sym)(A, arg)
end
@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype, sym, arg::Union{AbstractArray{Int}, AbstractArray{Bool}})
$paramcheck
getu(A, sym).((A,), arg)
end
@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype, sym, arg::Colon)
$paramcheck
getu(A, sym)(A)
end
end

Base.@propagate_inbounds function _getindex(
A::AbstractDiffEqArray, ::ScalarSymbolic, sym::Union{Tuple, AbstractArray}, args...)
return reduce(vcat, map(s -> A[s, args...]', sym))
end

Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ScalarSymbolic,
Expand Down
7 changes: 5 additions & 2 deletions test/downstream/symbol_indexing.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using RecursiveArrayTools, ModelingToolkit, OrdinaryDiffEq, SymbolicIndexingInterface, Test
using Zygote
using ModelingToolkit: t_nounits as t, D_nounits as D

include("../testutils.jl")
Expand Down Expand Up @@ -35,10 +36,12 @@ sol_new = DiffEqArray(sol.u[1:10],
@test_throws Exception sol_new[τ]

gs, = Zygote.gradient(sol) do sol
sum(sol[fol_separate.x])
sum(sol[fol_separate.x])
end

@test "Symbolic Indexing ADjoint" all(all.(isone, gs.u))
@testset "Symbolic Indexing ADjoint" begin
@test all(all.(isone, gs.u))
end

# Tables interface
test_tables_interface(sol_new, [:timestamp, Symbol("x(t)")], hcat(sol_new[t], sol_new[x]))
Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ end
if GROUP == "SymbolicIndexingInterface" || GROUP == "Downstream"
if GROUP == "SymbolicIndexingInterface"
activate_downstream_env()
@time @safetestset "DiffEqArray Indexing Tests" include("downstream/symbol_indexing.jl")
end
@time @safetestset "DiffEqArray Indexing Tests" include("downstream/symbol_indexing.jl")
end

if GROUP == "GPU"
Expand Down
9 changes: 2 additions & 7 deletions test/symbolic_indexing_interface_test.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using RecursiveArrayTools, Test, LabelledArrays, SymbolicIndexingInterface
using RecursiveArrayTools, Test, SymbolicIndexingInterface

t = 0.0:0.1:1.0
f(x) = 2x
Expand All @@ -20,7 +20,7 @@ dx = DiffEqArray([[f(x), f2(x)] for x in t],
@test dx[[:a, :b]] [[f(x), f2(x)] for x in t]
@test dx[(:a, :b)] == [(f(x), f2(x)) for x in t]
@test dx[[:a, :b], 3] [f(t[3]), f2(t[3])]
@test dx[[:a, :b], 4:5] vcat(f.(t[4:5])', f2.(t[4:5])')
@test dx[[:a, :b], 4:5] vcat.(f.(t[4:5]), f2.(t[4:5]))
@test dx[solvedvariables] == dx[allvariables] == dx[[:a, :b]]
@test dx[solvedvariables, 3] == dx[allvariables, 3] == dx[[:a, :b], 3]
@test getp(dx, [:p, :q])(dx) == [1.0, 2.0]
Expand Down Expand Up @@ -53,8 +53,3 @@ get_tuple = getu(dx, (:a, :b))

dx = DiffEqArray([[f(x), f2(x)] for x in t], t; variables = [:a, :b])
@test_throws Exception dx[nothing] # make sure it isn't storing [nothing] as indepsym

ABC = @SLVector (:a, :b, :c);
A = ABC(1, 2, 3);
B = RecursiveArrayTools.DiffEqArray([A, A], [0.0, 2.0]);
@test getindex(B, :a) == [1, 1]

0 comments on commit 9508294

Please sign in to comment.