Skip to content

Commit

Permalink
Add struct tracing
Browse files Browse the repository at this point in the history
Co-authored-by: Fredrik Bagge Carlson <[email protected]>
  • Loading branch information
YingboMa and baggepinnen committed Jan 25, 2024
1 parent 1bdc3c9 commit 10219bf
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 6 deletions.
13 changes: 7 additions & 6 deletions src/Symbolics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,19 @@ using PrecompileTools
using Setfield

import DomainSets: Domain

import SymbolicUtils: similarterm, istree, operation, arguments, symtype, metadata

import SymbolicUtils: Term, Add, Mul, Pow, Sym, Div, BasicSymbolic,
FnType, @rule, Rewriters, substitute,
promote_symtype, isadd, ismul, ispow, isterm, issym, isdiv

using SymbolicUtils.Code

import SymbolicUtils.Rewriters: Chain, Prewalk, Postwalk, Fixpoint

import SymbolicUtils.Code: toexpr

import ArrayInterface
using RuntimeGeneratedFunctions
using SciMLBase, IfElse
Expand Down Expand Up @@ -145,6 +145,7 @@ include("parsing.jl")
export parse_expr_to_symbolic

include("error_hints.jl")
include("struct.jl")

# Hacks to make wrappers "nicer"
const NumberTypes = Union{AbstractFloat,Integer,Complex{<:AbstractFloat},Complex{<:Integer}}
Expand Down
92 changes: 92 additions & 0 deletions src/struct.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
const TypeT = UInt32
const ISINTEGER = TypeT(0)
const SIGNED_OFFSET = TypeT(1)
const SIZE_OFFSET = TypeT(2)

const EMPTY_DIMS = Int[]

struct StructElement
typ::TypeT
name::Symbol
size::Vector{Int}
function StructElement(::Type{T}, name, size = EMPTY_DIMS) where {T}
c = encodetyp(T)
c == typemax(TypeT) && error("Cannot handle type $T")
new(c, name, size)
end
end

_sizeofrepr(typ::TypeT) = typ >> SIZE_OFFSET
sizeofrepr(s::StructElement) = _sizeofrepr(s.typ)
Base.size(s::StructElement) = s.size
Base.length(s::StructElement) = prod(size(s))
Base.nameof(s::StructElement) = s.name
function Base.show(io::IO, s::StructElement)
print(io, nameof(s), "::", decodetyp(s.typ))
if length(s) > 1
print(io, "::(", join(size(s), " × "), ")")
end
end

function encodetyp(::Type{T}) where {T}
typ = zero(UInt32)
if T <: Integer
typ |= TypeT(1) << ISINTEGER
if T <: Signed
typ |= TypeT(1) << SIGNED_OFFSET
elseif !(T <: Unsigned)
return typemax(TypeT)
end
elseif !(T <: AbstractFloat)
return typemax(TypeT)
end
typ |= TypeT(sizeof(T)) << SIZE_OFFSET
end

function decodetyp(typ::TypeT)
siz = TypeT(8) * (typ >> SIZE_OFFSET)

Check warning on line 47 in src/struct.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"siz" should be "six" or "size".
if !iszero(typ & (TypeT(1) << ISINTEGER))
if !iszero(typ & TypeT(1) << SIGNED_OFFSET)
siz == 8 ? Int8 :

Check warning on line 50 in src/struct.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"siz" should be "six" or "size".
siz == 16 ? Int16 :

Check warning on line 51 in src/struct.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"siz" should be "six" or "size".
siz == 32 ? Int32 :

Check warning on line 52 in src/struct.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"siz" should be "six" or "size".
siz == 64 ? Int64 :

Check warning on line 53 in src/struct.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"siz" should be "six" or "size".
error("invalid type $(typ)!")
else # unsigned
siz == 8 ? UInt8 :

Check warning on line 56 in src/struct.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"siz" should be "six" or "size".
siz == 16 ? UInt16 :

Check warning on line 57 in src/struct.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"siz" should be "six" or "size".
siz == 32 ? UInt32 :

Check warning on line 58 in src/struct.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"siz" should be "six" or "size".
siz == 64 ? UInt64 :

Check warning on line 59 in src/struct.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"siz" should be "six" or "size".
error("invalid type $(typ)!")
end
else # float
siz == 16 ? Float16 :

Check warning on line 63 in src/struct.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"siz" should be "six" or "size".
siz == 32 ? Float32 :
siz == 64 ? Float64 :
error("invalid type $(typ)!")
end
end

struct Struct
v::Vector{StructElement}
end

function Base.getproperty(s::Struct, name::Symbol)
v = getfield(s, :v)
idx = findfirst(x -> nameof(x) == name, v)
idx === nothing && error("no field $name in struct")
SymbolicUtils.term(getfield, s, idx, type = Real)
end

function Base.setproperty!(s::Struct, name::Symbol, x)
v = getfield(s, :v)
idx = findfirst(x -> nameof(x) == name, v)
idx === nothing && error("no field $name in struct")
type = SymbolicUtils.symtype(x)
SymbolicUtils.term(setfield!, s, idx, x; type)
end

# We cannot precisely derive the type after `getfield` due to SU limitations,
# so give up and just say Real.
SymbolicUtils.promote_symtype(::typeof(getfield), ::Type{<:Struct}, _...) = Real
SymbolicUtils.promote_symtype(::typeof(setfield!), ::Type{<:Struct}, _, ::Type{T}) where T = T
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ limit(a, N) = a == N + 1 ? 1 : a == 0 ? N : a
@register_symbolic limit(a, N)::Integer

if GROUP == "All" || GROUP == "Core"
@safetestset "Struct Test" begin include("struct.jl") end
@safetestset "Macro Test" begin include("macro.jl") end
@safetestset "Arrays" begin include("arrays.jl") end
@safetestset "View-setting" begin include("stencils.jl") end
Expand Down
40 changes: 40 additions & 0 deletions test/struct.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
using Test, Symbolics
using Symbolics: StructElement, Struct, operation, arguments

handledtypes = [Int8,
Int16,
Int32,
Int64,
UInt8,
UInt16,
UInt32,
UInt64,
Float16,
Float32,
Float64]
for t in handledtypes
@test Symbolics.decodetyp(Symbolics.encodetyp(t)) === t
end

@variables t x(t)
a = StructElement(Int8, :a)
b = StructElement(Int, :b)
s = Struct([a, b])
sa = s.a
sb = s.b
@test operation(sa) === getfield
@test arguments(sa) == Any[s, 1]
@test arguments(sa) isa Any
@test operation(sb) === getfield
@test arguments(sb) == Any[s, 2]
@test arguments(sb) isa Any

sa1 = (setproperty!(s, :a, UInt8(1)))
@test operation(sa1) === setfield!
@test arguments(sa1) == Any[s, 1, UInt8(1)]
@test arguments(sa1) isa Any

sb1 = (setproperty!(s, :b, "hi"))
@test operation(sb1) === setfield!
@test arguments(sb1) == Any[s, 2, "hi"]
@test arguments(sb1) isa Any

0 comments on commit 10219bf

Please sign in to comment.