Skip to content

Commit

Permalink
feat: support observed generation for array expressions
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Apr 30, 2024
1 parent 632032f commit d141d4b
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 22 deletions.
68 changes: 46 additions & 22 deletions src/symbol_cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,44 +69,68 @@ function independent_variable_symbols(sc::SymbolCache)
end
is_observed(sc::SymbolCache, sym) = false
is_observed(::SymbolCache, ::Expr) = true
is_observed(::SymbolCache, ::AbstractArray{Expr}) = true
is_observed(::SymbolCache, ::Tuple{Vararg{Expr}}) = true

struct ExpressionSearcher
declared::Set{Symbol}
fnbody::Expr
end

ExpressionSearcher() = ExpressionSearcher(Set{Symbol}(), Expr(:block))

function (exs::ExpressionSearcher)(sys, expr::Expr)
for arg in expr.args
exs(sys, arg)
end
exs(sys, expr.head)
return nothing
end

function (exs::ExpressionSearcher)(sys, sym::Symbol)
sym in exs.declared && return
if is_variable(sys, sym)
idx = variable_index(sys, sym)
push!(exs.fnbody.args, :($sym = u[$idx]))
elseif is_parameter(sys, sym)
idx = parameter_index(sys, sym)
push!(exs.fnbody.args, :($sym = p[$idx]))
elseif is_independent_variable(sys, sym)
push!(exs.fnbody.args, :($sym = t))
end
push!(exs.declared, sym)
return nothing
end

(::ExpressionSearcher)(sys, sym) = nothing

function observed(sc::SymbolCache, expr::Expr)
let cache = Dict{Expr, Function}()
return get!(cache, expr) do
fnbody = Expr(:block)
declared = Set{Symbol}()
MacroTools.postwalk(expr) do sym
sym isa Symbol || return
sym in declared && return
if sc.variables !== nothing &&
(idx = findfirst(isequal(sym), sc.variables)) !== nothing
push!(fnbody.args, :($sym = u[$idx]))
push!(declared, sym)
elseif sc.parameters !== nothing &&
(idx = findfirst(isequal(sym), sc.parameters)) !== nothing
push!(fnbody.args, :($sym = p[$idx]))
push!(declared, sym)
elseif sym === sc.independent_variables ||
sc.independent_variables isa Vector &&
sym == only(sc.independent_variables)
push!(fnbody.args, :($sym = t))
push!(declared, sym)
end
end
exs = ExpressionSearcher()
exs(sc, expr)
fnexpr = if is_time_dependent(sc)
:(function (u, p, t)
$fnbody
$(exs.fnbody)
return $expr
end)
else
:(function (u, p)
$fnbody
$(exs.fnbody)
return $expr
end)
end
return RuntimeGeneratedFunctions.@RuntimeGeneratedFunction(fnexpr)
end
end
end
function observed(sc::SymbolCache, exprs::AbstractArray{Expr})
return observed(sc, :(reshape([$(exprs...)], $(size(exprs)))))
end
function observed(sc::SymbolCache, exprs::Tuple{Vararg{Expr}})
return observed(sc, :(($(exprs...),)))
end

function is_time_dependent(sc::SymbolCache)
sc.independent_variables === nothing && return false
if symbolic_type(sc.independent_variables) == NotSymbolic()
Expand Down
11 changes: 11 additions & 0 deletions test/symbol_cache_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,17 @@ obsfn = observed(sc, :(x + a + t))
obsfn2 = observed(sc, :(x + a + t))
@test obsfn === obsfn2

@test is_observed(sc, [:(x + a), :(a + t)])
obsfn3 = observed(sc, [:(x + a), :(a + t)])
@test obsfn3(ones(3), 2ones(2), 3.0) [3.0, 5.0]
@test is_observed(sc, [:(x + a) :(y + b); :(x + y) :(a + b)])
obsfn4 = observed(sc, [:(x + a) :(y + b); :(x + y) :(a + b)])
@test size(obsfn4(ones(3), 2ones(2), 3.0)) == (2, 2)
@test obsfn4(ones(3), 2ones(2), 3.0) [3.0 3.0; 2.0 4.0]
@test is_observed(sc, (:(x + a), :(y + b)))
obsfn5 = observed(sc, (:(x + a), :(y + b)))
@test all(obsfn5(ones(3), 2ones(2), 3.0) .≈ (3.0, 3.0))

sc = SymbolCache([:x, :y], [:a, :b])
@test !is_time_dependent(sc)
@test sort(all_symbols(sc)) == [:a, :b, :x, :y]
Expand Down

0 comments on commit d141d4b

Please sign in to comment.