diff --git a/src/struct.jl b/src/struct.jl index e6ff592f9..7aaa54307 100644 --- a/src/struct.jl +++ b/src/struct.jl @@ -68,9 +68,31 @@ function decodetyp(typ::TypeT) end struct Struct + juliatype::DataType v::Vector{StructElement} end + +""" + symstruct(T) + +Create a symbolic struct from a given type `T`. +""" +function symstruct(T) + elems = map(fieldnames(T)) do fieldname + StructElement(fieldtype(T, fieldname), fieldname) + end |> collect + Struct(T, elems) +end + +""" + juliatype(s::Struct) + +Get the Julia type that `s` is representing. +""" +juliatype(s::Struct) = getfield(s, :juliatype) +getelements(s::Struct) = getfield(s, :v) + function Base.getproperty(s::Struct, name::Symbol) v = getfield(s, :v) idx = findfirst(x -> nameof(x) == name, v) diff --git a/test/struct.jl b/test/struct.jl index b509abc51..c9b178bc6 100644 --- a/test/struct.jl +++ b/test/struct.jl @@ -1,5 +1,5 @@ using Test, Symbolics -using Symbolics: StructElement, Struct, operation, arguments +using Symbolics: StructElement, Struct, operation, arguments, symstruct, juliatype handledtypes = [Int8, Int16, @@ -17,17 +17,25 @@ for t in handledtypes end @variables t x(t) +struct Fisk + a::Int8 + b::Int +end a = StructElement(Int8, :a) b = StructElement(Int, :b) -s = Struct([a, b]) -sa = s.a -sb = s.b -@test operation(sa) === getfield -@test arguments(sa) == Any[s, 1] -@test arguments(sa) isa Any -@test operation(sb) === getfield -@test arguments(sb) == Any[s, 2] -@test arguments(sb) isa Any +for s in [Struct(Fisk, [a, b]), symstruct(Fisk)] + sa = s.a + sb = s.b + @test operation(sa) === getfield + @test arguments(sa) == Any[s, 1] + @test arguments(sa) isa Any + @test operation(sb) === getfield + @test arguments(sb) == Any[s, 2] + @test arguments(sb) isa Any + @test juliatype(s) == Fisk +end + +s = Struct(Fisk, [a, b]) sa1 = (setproperty!(s, :a, UInt8(1))) @test operation(sa1) === setfield! @@ -38,3 +46,12 @@ sb1 = (setproperty!(s, :b, "hi")) @test operation(sb1) === setfield! @test arguments(sb1) == Any[s, 2, "hi"] @test arguments(sb1) isa Any + +struct Jörgen + a::Int + b::Float64 +end + +ss = symstruct(Jörgen) + +@test getfield(ss, :v) == [StructElement(Int, :a), StructElement(Float64, :b)]