Skip to content

Commit

Permalink
fix: fix observed function generation for systems with inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Jun 12, 2024
1 parent 9050e70 commit 3f543ae
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -472,13 +472,17 @@ function build_explicit_observed_function(sys, ts;
ps = DestructuredArgs.(ps, inbounds = !checkbounds)
elseif has_index_cache(sys) && get_index_cache(sys) !== nothing
ps = DestructuredArgs.(reorder_parameters(get_index_cache(sys), ps))
if isempty(ps) && inputs !== nothing
ps = (:EMPTY,)

Check warning on line 476 in src/systems/diffeqs/odesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/odesystem.jl#L475-L476

Added lines #L475 - L476 were not covered by tests
end
else
ps = (DestructuredArgs(ps, inbounds = !checkbounds),)
end
dvs = DestructuredArgs(unknowns(sys), inbounds = !checkbounds)
if inputs === nothing
args = [dvs, ps..., ivs...]
else
inputs = unwrap.(inputs)

Check warning on line 485 in src/systems/diffeqs/odesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/odesystem.jl#L485

Added line #L485 was not covered by tests
ipts = DestructuredArgs(inputs, inbounds = !checkbounds)
args = [dvs, ipts, ps..., ivs...]
end
Expand Down
13 changes: 13 additions & 0 deletions test/input_output_handling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -378,3 +378,16 @@ matrices, ssys = linearize(augmented_sys,
# P = ss(A,B,C,0)
# G = ss(matrices...)
# @test sminreal(G[1, 3]) ≈ sminreal(P[1,1])*dist

@testset "Observed functions with inputs" begin
@variables x(t)=0 u(t)=0 [input = true]
eqs = [
D(x) ~ -x + u
]

@named sys = ODESystem(eqs, t)
(; io_sys,) = ModelingToolkit.generate_control_function(sys, simplify = true)
obsfn = ModelingToolkit.build_explicit_observed_function(
io_sys, [x + u * t]; inputs = [u])
@test obsfn([1.0], [2.0], nothing, 3.0) == [7.0]
end

0 comments on commit 3f543ae

Please sign in to comment.