Skip to content

Commit

Permalink
fix: bugs related to observed
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Feb 6, 2024
1 parent 1cb0ec2 commit d8823c9
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 20 deletions.
21 changes: 5 additions & 16 deletions src/scimlfunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4046,26 +4046,15 @@ end

SymbolicIndexingInterface.symbolic_container(fn::AbstractSciMLFunction) = has_sys(fn) ? fn.sys : SymbolCache()

SymbolicIndexingInterface.is_observed(fn::AbstractSciMLFunction, sym) = has_observed(fn)
SymbolicIndexingInterface.is_observed(fn::AbstractSciMLFunction, sym) = has_sys(fn) ? is_observed(fn.sys, sym) : has_observed(fn)

function SymbolicIndexingInterface.observed(fn::AbstractSciMLFunction, sym)
if has_observed(fn)
if is_time_dependent(fn)
return if hasmethod(fn.observed, Tuple{typeof(sym)})
fn.observed(sym)
else
let obs = fn.observed, sym = sym
(u, p, t) -> obs(sym, u, p, t)
end
end
if hasmethod(fn.observed, Tuple{Any})
@show sym
return fn.observed(sym)
else
return if hasmethod(fn.observed, Tuple{typeof(sym)})
fn.observed(sym)
else
let obs = fn.observed, sym = sym
(u, p) -> obs(sym, u, p)
end
end
return (args...) -> fn.observed(sym, args...)
end
end
error("SciMLFunction does not have observed")
Expand Down
6 changes: 2 additions & 4 deletions test/downstream/problem_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -140,14 +140,12 @@ getx = getu(sprob, x)
gety = getu(sprob, :y)
get_arr = getu(sprob, [x, y])
get_tuple = getu(sprob, (y, z))
# observed doesn't work the same for SDEs
# uncomment get_obs test below when fixed
@test_broken get_obs = getu(sprob, sys.x + sys.z + t + σ)
get_obs = getu(sprob, sys.x + sys.z + t + σ)
@test getx(sprob) == 10.0
@test gety(sprob) == 10.0
@test get_arr(sprob) == [10.0, 10.0]
@test get_tuple(sprob) == (10.0, 1.0)
# @test get_obs(sprob) == 39.0
@test get_obs(sprob) == 22.0

setx! = setu(sprob, x)
sety! = setu(sprob, :y)
Expand Down

0 comments on commit d8823c9

Please sign in to comment.