Skip to content

Commit

Permalink
Merge pull request #1091 from JuliaSymbolics/fb/getfieldtype
Browse files Browse the repository at this point in the history
fix: handle type propagation in getfield
  • Loading branch information
YingboMa authored Mar 14, 2024
2 parents fc74b9a + 1b4bd50 commit b26b0ab
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 5 deletions.
23 changes: 20 additions & 3 deletions src/struct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,17 @@ juliatype(::Type{Struct{T}}) where T = T
getelements(s::Type{<:Struct}) = fieldnames(juliatype(s))
getelementtypes(s::Type{<:Struct}) = fieldtypes(juliatype(s))

typed_getfield(obj, ::Val{fieldname}) where fieldname = getfield(obj, fieldname)

function symbolic_getproperty(ss, name::Symbol)
s = symtype(ss)
idx = findfirst(isequal(name), getelements(s))
idx === nothing && error("$(juliatype(s)) doesn't have field $(name)!")
T = getelementtypes(s)[idx]
SymbolicUtils.term(getfield, ss, Meta.quot(name), type = T)
if isstructtype(T)
T = Struct{T}
end
SymbolicUtils.term(typed_getfield, ss, Val{name}(), type = T)
end
function symbolic_getproperty(s::Union{Arr, Num}, name::Symbol)
wrap(symbolic_getproperty(unwrap(s), name))
Expand All @@ -54,6 +59,18 @@ end

# We cannot precisely derive the type after `getfield` due to SU limitations,
# so give up and just say Real.
SymbolicUtils.promote_symtype(::typeof(getfield), ::Type{<:Struct}, _...) = Real
function SymbolicUtils.promote_symtype(::typeof(typed_getfield), ::Type{<:Struct{T}}, v::Type{Val{fieldname}}) where {T, fieldname}
FT = fieldtype(T, fieldname)
if isstructtype(FT)
return Struct{FT}
end
FT
end


SymbolicUtils.promote_symtype(::typeof(setfield!), ::Type{<:Struct}, _, ::Type{T}) where T = T
function SymbolicUtils.promote_symtype(s::Type{<:Struct{T}}, _...) where T
s
end

SymbolicUtils.promote_symtype(::typeof(setfield!), ::Type{<:Struct}, _, ::Type{T}) where T = T
SymbolicUtils.promote_symtype(s::Type{<:Struct{T}}, _...) where T = s
4 changes: 2 additions & 2 deletions test/struct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ S = symstruct(Jörgen)
@variables x::S
xa = Symbolics.unwrap(symbolic_getproperty(x, :a))
@test Symbolics.symtype(xa) == Int
@test Symbolics.operation(xa) == getfield
@test isequal(Symbolics.arguments(xa), [Symbolics.unwrap(x), Meta.quot(:a)])
@test Symbolics.operation(xa) == Symbolics.typed_getfield
@test isequal(Symbolics.arguments(xa), [Symbolics.unwrap(x), Val{:a}()])
xa = Symbolics.unwrap(symbolic_setproperty!(x, :a, 10))
@test Symbolics.operation(xa) == setfield!
@test isequal(Symbolics.arguments(xa), [Symbolics.unwrap(x), Meta.quot(:a), 10])
Expand Down

0 comments on commit b26b0ab

Please sign in to comment.