Skip to content

Commit

Permalink
wip: add StructIO
Browse files Browse the repository at this point in the history
  • Loading branch information
baggepinnen committed Feb 16, 2024
1 parent 5ee8170 commit 08742b6
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 0 deletions.
38 changes: 38 additions & 0 deletions src/Blocks/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,41 @@ Base class for a multiple input multiple output (MIMO) continuous system block.
]
return ODESystem(eqs, t, vcat(u..., y...), []; name = name, systems = [input, output])
end



@connector function StructInput(; structdef, name)
n = Symbolics.juliatype(structdef)
@variables u(t)::n [input = true] # Dummy default value due to bug in Symbolics
ODESystem(Equation[], t, [u], []; name)

Check warning on line 116 in src/Blocks/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/Blocks/utils.jl#L113-L116

Added lines #L113 - L116 were not covered by tests
end

@connector function StructOutput(; structdef, name)
n = Symbolics.juliatype(structdef)
@variables u(t)::n [output = true] # Dummy default value due to bug in Symbolics
ODESystem(Equation[], t, [u], []; name)

Check warning on line 122 in src/Blocks/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/Blocks/utils.jl#L119-L122

Added lines #L119 - L122 were not covered by tests
end

function _structelem2connector(elem::StructElement)
T = Symbolics.decodetyp(elem.typ)
if T <: Bool
return BoolOutput(; name = elem.name)
elseif T <: Real
return RealOutput(; name = elem.name)

Check warning on line 130 in src/Blocks/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/Blocks/utils.jl#L125-L130

Added lines #L125 - L130 were not covered by tests
end
end

@component function BusSelect(;name, structdef, selected_fields)
@parameters t
nout = length(selected_fields)
inputbus = Blocks.StructInput(; structdef, name = Symbol("inputbus"))
@variables input(t)

Check warning on line 138 in src/Blocks/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/Blocks/utils.jl#L134-L138

Added lines #L134 - L138 were not covered by tests

output_elements = filter(e->e.name in selected_fields, getelements(structdef))
output_connectors = map(_structelem2connector, output_elements)

Check warning on line 141 in src/Blocks/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/Blocks/utils.jl#L140-L141

Added lines #L140 - L141 were not covered by tests

eqs = [

Check warning on line 143 in src/Blocks/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/Blocks/utils.jl#L143

Added line #L143 was not covered by tests
getproperty(inputbus.u, field) ~ con.u for (field, con) in zip(selected_fields, output_connectors)
]
return ODESystem(eqs, t; name = name, systems = [inputbus; output_connectors])

Check warning on line 146 in src/Blocks/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/Blocks/utils.jl#L146

Added line #L146 was not covered by tests
end
29 changes: 29 additions & 0 deletions test/Blocks/sources.jl
Original file line number Diff line number Diff line change
Expand Up @@ -474,3 +474,32 @@ end
@test sol[ddy][end]2 atol=1e-3
end
end

using Symbolics
using Symbolics: Struct, StructElement, getelements, symstruct
using Test
using ModelingToolkitStandardLibrary.Blocks
using ModelingToolkitStandardLibrary.Blocks: structelem2connector

# Test struct
struct BarStruct
speed::Float64
isSpeedValid::Int
end

bar = BarStruct(1.0, 1)
structdef = symstruct(BarStruct)
selected_fields = [:speed]

@mtkmodel BusSelectTest begin
@components begin
inputbus = Blocks.StructOutput(; structdef)
output = BusSelect(; structdef, selected_fields)
end
@equations begin
inputbus.u ~ bar
connect(inputbus, output)
end
end

@named sys = BusSelectTest()

0 comments on commit 08742b6

Please sign in to comment.