Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Graph evaluator caching and expression visualiser #98

Closed
wants to merge 19 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix: various aspects of degree interface
  • Loading branch information
MilesCranmer committed Jul 8, 2024
commit 2a0bd054578c88f47b70b61ad0141e40c8e6ce47
1 change: 1 addition & 0 deletions src/DynamicExpressions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ import .NodeModule:
constructorof,
with_type_parameters,
preserve_sharing,
max_degree,
leaf_copy,
branch_copy,
leaf_hash,
Expand Down
2 changes: 1 addition & 1 deletion src/Node.jl
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ end
# with_degree(::Type{N}, ::Val{D}) where {T,N<:Node{T},D} = Node{T,D}
# with_degree(::Type{N}, ::Val{D}) where {T,N<:GraphNode{T},D} = GraphNode{T,D}

function default_allocator(::Type{N}, ::Type{T}) where {N<:Union{Node,GraphNode},T}
function default_allocator(::Type{N}, ::Type{T}) where {N<:AbstractExpressionNode,T}
return with_type_parameters(N, T)()
end

Expand Down
11 changes: 11 additions & 0 deletions src/NodeUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,17 @@ mutable struct NodeIndex{T,D} <: AbstractNode{D}
end
NodeIndex(::Type{T}, ::Val{D}) where {T,D} = NodeIndex(T, Val(D), zero(T))

@inline function Base.getproperty(n::NodeIndex, k::Symbol)
if k == :l
# TODO: Should a depwarn be raised here? Or too slow?
return getfield(n, :children)[1][]
elseif k == :r
return getfield(n, :children)[2][]
else
return getfield(n, k)
end
end

# Sharing is never needed for NodeIndex,
# as we trace over the node we are indexing on.
preserve_sharing(::Union{Type{<:NodeIndex},NodeIndex}) = false
Expand Down
2 changes: 1 addition & 1 deletion src/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ function tree_mapreduce(
f_on_shared::H=(result, is_shared) -> result,
break_sharing::Val{BS}=Val(false),
) where {F1<:Function,F2<:Function,G<:Function,D,H<:Function,RT,BS}
sharing = preserve_sharing(typeof(tree)) && !break_sharing
sharing = preserve_sharing(typeof(tree)) && !BS

RT == Undefined &&
sharing &&
Expand Down
4 changes: 2 additions & 2 deletions test/test_base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@ end

@testset "collect" begin
ctree = copy(tree)
@test typeof(first(collect(ctree))) == Node{Float64}
@test typeof(first(collect(ctree))) <: Node{Float64}
@test objectid(first(collect(ctree))) == objectid(ctree)
@test objectid(first(collect(ctree))) == objectid(ctree)
@test objectid(first(collect(ctree))) == objectid(ctree)
@test typeof(collect(ctree)) == Vector{Node{Float64}}
@test typeof(collect(ctree)) <: Vector{<:Node{Float64}}
@test length(collect(ctree)) == 24
@test sum((t -> (t.degree == 0 && t.constant) ? t.val : 0.0).(collect(ctree))) ≈ 11.6
end
Expand Down
22 changes: 13 additions & 9 deletions test/test_custom_node_type.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
using DynamicExpressions
using Test

mutable struct MyCustomNode{A,B} <: AbstractNode
mutable struct MyCustomNode{A,B} <: AbstractNode{2}
degree::Int
val1::A
val2::B
l::MyCustomNode{A,B}
r::MyCustomNode{A,B}
children::NTuple{2,Base.RefValue{MyCustomNode{A,B}}}

MyCustomNode(val1, val2) = new{typeof(val1),typeof(val2)}(0, val1, val2)
MyCustomNode(val1, val2, l) = new{typeof(val1),typeof(val2)}(1, val1, val2, l)
MyCustomNode(val1, val2, l, r) = new{typeof(val1),typeof(val2)}(2, val1, val2, l, r)
function MyCustomNode(val1, val2, l)
return new{typeof(val1),typeof(val2)}(
1, val1, val2, (Ref(l), Ref{MyCustomNode{typeof(val1),typeof(val2)}}())
)
end
function MyCustomNode(val1, val2, l, r)
return new{typeof(val1),typeof(val2)}(2, val1, val2, (Ref(l), Ref(r)))
end
end

node1 = MyCustomNode(1.0, 2)
Expand All @@ -24,7 +29,7 @@ node2 = MyCustomNode(1.5, 3, node1)

@test typeof(node2) == MyCustomNode{Float64,Int}
@test node2.degree == 1
@test node2.l.degree == 0
@test node2.children[1][].degree == 0
@test count_depth(node2) == 2
@test count_nodes(node2) == 2

Expand All @@ -37,14 +42,13 @@ node2 = MyCustomNode(1.5, 3, node1, node1)
@test count(t -> t.degree == 0, node2) == 2

# If we have a bad definition, it should get caught with a helpful message
mutable struct MyCustomNode2{T} <: AbstractExpressionNode{T}
mutable struct MyCustomNode2{T} <: AbstractExpressionNode{T,2}
degree::UInt8
constant::Bool
val::T
feature::UInt16
op::UInt8
l::MyCustomNode2{T}
r::MyCustomNode2{T}
children::NTuple{2,Base.RefValue{MyCustomNode2{T}}}
end

@test_throws ErrorException MyCustomNode2()
Expand Down
4 changes: 2 additions & 2 deletions test/test_equality.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ modified_tree5 = 1.5 * cos(x2 * x1) + x1 + x2 * x3 - log(x2 * 3.2)

f64_tree = GraphNode{Float64}(x1 + x2 * x3 - log(x2 * 3.0) + 1.5 * cos(x2 / x1))
f32_tree = GraphNode{Float32}(x1 + x2 * x3 - log(x2 * 3.0) + 1.5 * cos(x2 / x1))
@test typeof(f64_tree) == GraphNode{Float64}
@test typeof(f32_tree) == GraphNode{Float32}
@test typeof(f64_tree) <: GraphNode{Float64}
@test typeof(f32_tree) <: GraphNode{Float32}

@test convert(GraphNode{Float64}, f32_tree) == f64_tree

Expand Down
25 changes: 16 additions & 9 deletions test/test_extra_node_fields.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,31 @@

using Test
using DynamicExpressions
using DynamicExpressions: constructorof
using DynamicExpressions: constructorof, max_degree

mutable struct FrozenNode{T} <: AbstractExpressionNode{T}
mutable struct FrozenNode{T,D} <: AbstractExpressionNode{T,D}
degree::UInt8
constant::Bool
val::T
frozen::Bool # Extra field!
feature::UInt16
op::UInt8
l::FrozenNode{T}
r::FrozenNode{T}
children::NTuple{D,Base.RefValue{FrozenNode{T,D}}}

function FrozenNode{_T}() where {_T}
n = new{_T}()
function FrozenNode{_T,_D}() where {_T,_D}
n = new{_T,_D}()
n.frozen = false
return n
end
end
function DynamicExpressions.constructorof(::Type{N}) where {N<:FrozenNode}
return FrozenNode{T,max_degree(N)} where {T}
end
function DynamicExpressions.with_type_parameters(
::Type{N}, ::Type{T}
) where {T,N<:FrozenNode}
return FrozenNode{T,max_degree(N)}
end
function DynamicExpressions.leaf_copy(t::FrozenNode{T}) where {T}
out = if t.constant
constructorof(typeof(t))(; val=t.val)
Expand Down Expand Up @@ -56,7 +63,7 @@ function DynamicExpressions.leaf_equal(a::FrozenNode, b::FrozenNode)
end
end

n = let n = FrozenNode{Float64}()
n = let n = FrozenNode{Float64,2}()
n.degree = 0
n.constant = true
n.val = 0.0
Expand Down Expand Up @@ -92,5 +99,5 @@ ex = parse_expression(

@test string_tree(ex) == "x + sin(y + 2.1)"
@test ex.tree.frozen == false
@test ex.tree.r.frozen == true
@test ex.tree.r.l.frozen == false
@test ex.tree.children[2][].frozen == true
@test ex.tree.children[2][].children[1][].frozen == false
13 changes: 1 addition & 12 deletions test/test_graphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,17 +109,6 @@ end
:(_convert(Node{T1}, tree, IdDict{Node{T2},Node{T1}}())),
)
end

@testset "@with_memoize" begin
ex = @macroexpand DynamicExpressions.UtilsModule.@with_memoize(
_convert(Node{T1}, tree), IdDict{Node{T2},Node{T1}}()
)
true_ex = quote
_convert(Node{T1}, tree, IdDict{Node{T2},Node{T1}}())
end

@test expr_eql(ex, true_ex)
end
end

@testset "Operations on graphs" begin
Expand Down Expand Up @@ -283,7 +272,7 @@ end
x = GraphNode(Float32; feature=1)
tree = x + 1.0
@test tree.l === x
@test typeof(tree) === GraphNode{Float32}
@test typeof(tree) <: GraphNode{Float32}

# Detect error from Float32(1im)
@test_throws InexactError x + 1im
Expand Down
4 changes: 2 additions & 2 deletions test/test_parse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ end
variable_names = ["x"],
)

@test typeof(ex.tree) === Node{Any}
@test typeof(ex.tree) <: Node{Any}
@test typeof(ex.metadata.operators) <: GenericOperatorEnum
s = sprint((io, e) -> show(io, MIME("text/plain"), e), ex)
@test s == "[1, 2, 3] * tan(cos(5.0 + x))"
Expand Down Expand Up @@ -184,7 +184,7 @@ end
s = sprint((io, e) -> show(io, MIME("text/plain"), e), ex)
@test s == "(x * 2.5) - cos(y)"
end
@test contains(logged_out, "Node{Float32}")
@test contains(logged_out, "Node{Float32")
end

@testitem "Helpful errors for missing operator" begin
Expand Down
Loading