From 2570016e1b6115fa3794837639a435e733e588bc Mon Sep 17 00:00:00 2001 From: Fredrik Bagge Carlson Date: Thu, 14 Mar 2024 11:01:55 +0100 Subject: [PATCH 1/2] fix: handle type propagation in getfield --- src/struct.jl | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/src/struct.jl b/src/struct.jl index 421a10968..9e0ac5b1a 100644 --- a/src/struct.jl +++ b/src/struct.jl @@ -23,12 +23,17 @@ juliatype(::Type{Struct{T}}) where T = T getelements(s::Type{<:Struct}) = fieldnames(juliatype(s)) getelementtypes(s::Type{<:Struct}) = fieldtypes(juliatype(s)) +typed_getfield(obj, ::Val{fieldname}) where fieldname = getfield(obj, fieldname) + function symbolic_getproperty(ss, name::Symbol) s = symtype(ss) idx = findfirst(isequal(name), getelements(s)) idx === nothing && error("$(juliatype(s)) doesn't have field $(name)!") T = getelementtypes(s)[idx] - SymbolicUtils.term(getfield, ss, Meta.quot(name), type = T) + if isstructtype(T) + T = Struct{T} + end + SymbolicUtils.term(typed_getfield, ss, Val{name}(), type = T) end function symbolic_getproperty(s::Union{Arr, Num}, name::Symbol) wrap(symbolic_getproperty(unwrap(s), name)) @@ -54,6 +59,18 @@ end # We cannot precisely derive the type after `getfield` due to SU limitations, # so give up and just say Real. -SymbolicUtils.promote_symtype(::typeof(getfield), ::Type{<:Struct}, _...) = Real +function SymbolicUtils.promote_symtype(::typeof(typed_getfield), ::Type{<:Struct{T}}, v::Type{Val{fieldname}}) where {T, fieldname} + FT = fieldtype(T, fieldname) + if isstructtype(FT) + return Struct{FT} + end + FT +end + + +SymbolicUtils.promote_symtype(::typeof(setfield!), ::Type{<:Struct}, _, ::Type{T}) where T = T +function SymbolicUtils.promote_symtype(s::Type{<:Struct{T}}, _...) where T + s +end + SymbolicUtils.promote_symtype(::typeof(setfield!), ::Type{<:Struct}, _, ::Type{T}) where T = T -SymbolicUtils.promote_symtype(s::Type{<:Struct{T}}, _...) where T = s From 1b4bd50d08503616682d0a3a29cc5db36d856f03 Mon Sep 17 00:00:00 2001 From: Fredrik Bagge Carlson Date: Thu, 14 Mar 2024 11:15:49 +0100 Subject: [PATCH 2/2] update test --- test/struct.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/struct.jl b/test/struct.jl index 3e842932c..8eef42dc1 100644 --- a/test/struct.jl +++ b/test/struct.jl @@ -10,8 +10,8 @@ S = symstruct(Jörgen) @variables x::S xa = Symbolics.unwrap(symbolic_getproperty(x, :a)) @test Symbolics.symtype(xa) == Int -@test Symbolics.operation(xa) == getfield -@test isequal(Symbolics.arguments(xa), [Symbolics.unwrap(x), Meta.quot(:a)]) +@test Symbolics.operation(xa) == Symbolics.typed_getfield +@test isequal(Symbolics.arguments(xa), [Symbolics.unwrap(x), Val{:a}()]) xa = Symbolics.unwrap(symbolic_setproperty!(x, :a, 10)) @test Symbolics.operation(xa) == setfield! @test isequal(Symbolics.arguments(xa), [Symbolics.unwrap(x), Meta.quot(:a), 10])