Skip to content

Commit

Permalink
Merge pull request #70 from SciML/as/array-observed
Browse files Browse the repository at this point in the history
feat: support observed generation for array expressions
  • Loading branch information
ChrisRackauckas authored Apr 30, 2024
2 parents 5a78f02 + 3cae58c commit c6bf421
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 25 deletions.
2 changes: 0 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,13 @@ version = "0.3.18"
[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"

[compat]
Accessors = "0.1.36"
Aqua = "0.8"
ArrayInterface = "7.9"
MacroTools = "0.5.13"
Pkg = "1"
RuntimeGeneratedFunctions = "0.5.12"
SafeTestsets = "0.0.1"
Expand Down
1 change: 0 additions & 1 deletion src/SymbolicIndexingInterface.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
module SymbolicIndexingInterface

import MacroTools
using RuntimeGeneratedFunctions
import StaticArraysCore: MArray, similar_type
import ArrayInterface
Expand Down
68 changes: 46 additions & 22 deletions src/symbol_cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,44 +69,68 @@ function independent_variable_symbols(sc::SymbolCache)
end
is_observed(sc::SymbolCache, sym) = false
is_observed(::SymbolCache, ::Expr) = true
is_observed(::SymbolCache, ::AbstractArray{Expr}) = true
is_observed(::SymbolCache, ::Tuple{Vararg{Expr}}) = true

struct ExpressionSearcher
declared::Set{Symbol}
fnbody::Expr
end

ExpressionSearcher() = ExpressionSearcher(Set{Symbol}(), Expr(:block))

function (exs::ExpressionSearcher)(sys, expr::Expr)
for arg in expr.args
exs(sys, arg)
end
exs(sys, expr.head)
return nothing
end

function (exs::ExpressionSearcher)(sys, sym::Symbol)
sym in exs.declared && return
if is_variable(sys, sym)
idx = variable_index(sys, sym)
push!(exs.fnbody.args, :($sym = u[$idx]))
elseif is_parameter(sys, sym)
idx = parameter_index(sys, sym)
push!(exs.fnbody.args, :($sym = p[$idx]))
elseif is_independent_variable(sys, sym)
push!(exs.fnbody.args, :($sym = t))
end
push!(exs.declared, sym)
return nothing
end

(::ExpressionSearcher)(sys, sym) = nothing

function observed(sc::SymbolCache, expr::Expr)
let cache = Dict{Expr, Function}()
return get!(cache, expr) do
fnbody = Expr(:block)
declared = Set{Symbol}()
MacroTools.postwalk(expr) do sym
sym isa Symbol || return
sym in declared && return
if sc.variables !== nothing &&
(idx = findfirst(isequal(sym), sc.variables)) !== nothing
push!(fnbody.args, :($sym = u[$idx]))
push!(declared, sym)
elseif sc.parameters !== nothing &&
(idx = findfirst(isequal(sym), sc.parameters)) !== nothing
push!(fnbody.args, :($sym = p[$idx]))
push!(declared, sym)
elseif sym === sc.independent_variables ||
sc.independent_variables isa Vector &&
sym == only(sc.independent_variables)
push!(fnbody.args, :($sym = t))
push!(declared, sym)
end
end
exs = ExpressionSearcher()
exs(sc, expr)
fnexpr = if is_time_dependent(sc)
:(function (u, p, t)
$fnbody
$(exs.fnbody)
return $expr
end)
else
:(function (u, p)
$fnbody
$(exs.fnbody)
return $expr
end)
end
return RuntimeGeneratedFunctions.@RuntimeGeneratedFunction(fnexpr)
end
end
end
function observed(sc::SymbolCache, exprs::AbstractArray{Expr})
return observed(sc, :(reshape([$(exprs...)], $(size(exprs)))))
end
function observed(sc::SymbolCache, exprs::Tuple{Vararg{Expr}})
return observed(sc, :(($(exprs...),)))
end

function is_time_dependent(sc::SymbolCache)
sc.independent_variables === nothing && return false
if symbolic_type(sc.independent_variables) == NotSymbolic()
Expand Down
11 changes: 11 additions & 0 deletions test/symbol_cache_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,17 @@ obsfn = observed(sc, :(x + a + t))
obsfn2 = observed(sc, :(x + a + t))
@test obsfn === obsfn2

@test is_observed(sc, [:(x + a), :(a + t)])
obsfn3 = observed(sc, [:(x + a), :(a + t)])
@test obsfn3(ones(3), 2ones(2), 3.0) [3.0, 5.0]
@test is_observed(sc, [:(x + a) :(y + b); :(x + y) :(a + b)])
obsfn4 = observed(sc, [:(x + a) :(y + b); :(x + y) :(a + b)])
@test size(obsfn4(ones(3), 2ones(2), 3.0)) == (2, 2)
@test obsfn4(ones(3), 2ones(2), 3.0) [3.0 3.0; 2.0 4.0]
@test is_observed(sc, (:(x + a), :(y + b)))
obsfn5 = observed(sc, (:(x + a), :(y + b)))
@test all(obsfn5(ones(3), 2ones(2), 3.0) .≈ (3.0, 3.0))

sc = SymbolCache([:x, :y], [:a, :b])
@test !is_time_dependent(sc)
@test sort(all_symbols(sc)) == [:a, :b, :x, :y]
Expand Down

0 comments on commit c6bf421

Please sign in to comment.