From d8823c9b573d952da7062a8e233fb8e5c5ac0a0d Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 6 Feb 2024 16:55:51 +0530 Subject: [PATCH] fix: bugs related to `observed` --- src/scimlfunctions.jl | 21 +++++---------------- test/downstream/problem_interface.jl | 6 ++---- 2 files changed, 7 insertions(+), 20 deletions(-) diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index d3464cdcd..aca1b304c 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -4046,26 +4046,15 @@ end SymbolicIndexingInterface.symbolic_container(fn::AbstractSciMLFunction) = has_sys(fn) ? fn.sys : SymbolCache() -SymbolicIndexingInterface.is_observed(fn::AbstractSciMLFunction, sym) = has_observed(fn) +SymbolicIndexingInterface.is_observed(fn::AbstractSciMLFunction, sym) = has_sys(fn) ? is_observed(fn.sys, sym) : has_observed(fn) function SymbolicIndexingInterface.observed(fn::AbstractSciMLFunction, sym) if has_observed(fn) - if is_time_dependent(fn) - return if hasmethod(fn.observed, Tuple{typeof(sym)}) - fn.observed(sym) - else - let obs = fn.observed, sym = sym - (u, p, t) -> obs(sym, u, p, t) - end - end + if hasmethod(fn.observed, Tuple{Any}) + @show sym + return fn.observed(sym) else - return if hasmethod(fn.observed, Tuple{typeof(sym)}) - fn.observed(sym) - else - let obs = fn.observed, sym = sym - (u, p) -> obs(sym, u, p) - end - end + return (args...) -> fn.observed(sym, args...) end end error("SciMLFunction does not have observed") diff --git a/test/downstream/problem_interface.jl b/test/downstream/problem_interface.jl index 15f587d64..90308ee26 100644 --- a/test/downstream/problem_interface.jl +++ b/test/downstream/problem_interface.jl @@ -140,14 +140,12 @@ getx = getu(sprob, x) gety = getu(sprob, :y) get_arr = getu(sprob, [x, y]) get_tuple = getu(sprob, (y, z)) -# observed doesn't work the same for SDEs -# uncomment get_obs test below when fixed -@test_broken get_obs = getu(sprob, sys.x + sys.z + t + σ) +get_obs = getu(sprob, sys.x + sys.z + t + σ) @test getx(sprob) == 10.0 @test gety(sprob) == 10.0 @test get_arr(sprob) == [10.0, 10.0] @test get_tuple(sprob) == (10.0, 1.0) -# @test get_obs(sprob) == 39.0 +@test get_obs(sprob) == 22.0 setx! = setu(sprob, x) sety! = setu(sprob, :y)