Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Graph evaluator caching and expression visualiser #98

Closed
wants to merge 19 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,27 @@ TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"

[weakdeps]
Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e"
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
GraphRecipes = "bd48cda9-67a9-57be-86fa-5b3c104eda73"
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"

[extensions]
DynamicExpressionsBumperExt = "Bumper"
DynamicExpressionsLoopVectorizationExt = "LoopVectorization"
DynamicExpressionsOptimExt = "Optim"
DynamicExpressionsSymbolicUtilsExt = "SymbolicUtils"
DynamicExpressionsZygoteExt = "Zygote"
DynamicExpressionsLoopVectorizationExt = "LoopVectorization"
DynamicExpressionsVisualizeExt = ["Plots","GraphRecipes"]

[compat]
Bumper = "0.6"
ChainRulesCore = "1"
Compat = "3.37, 4"
DispatchDoctor = "0.4"
Interfaces = "0.3"
LoopVectorization = "0.12"
MacroTools = "0.4, 0.5"
Optim = "0.19, 1"
PackageExtensionCompat = "1"
Expand All @@ -47,7 +49,6 @@ julia = "1.6"

[extras]
Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e"
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
66 changes: 64 additions & 2 deletions ext/DynamicExpressionsLoopVectorizationExt.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
module DynamicExpressionsLoopVectorizationExt

using LoopVectorization: @turbo
using DynamicExpressions: AbstractExpressionNode
using DynamicExpressions

using LoopVectorization: @turbo, vmapnt
using DynamicExpressions: AbstractExpressionNode, GraphNode, OperatorEnum
using DynamicExpressions.UtilsModule: ResultOk, fill_similar
using DynamicExpressions.EvaluateModule: @return_on_nonfinite_val, EvalOptions
import DynamicExpressions.EvaluateModule:
Expand All @@ -14,6 +16,7 @@ import DynamicExpressions.EvaluateModule:
deg2_r0_eval
import DynamicExpressions.ExtensionInterfaceModule:
_is_loopvectorization_loaded, bumper_kern1!, bumper_kern2!
import DynamicExpressions.ValueInterfaceModule: is_valid, is_valid_array

_is_loopvectorization_loaded(::Int) = true

Expand Down Expand Up @@ -230,4 +233,63 @@ function bumper_kern2!(
return cumulator1
end



# graph eval

function DynamicExpressions.EvaluateModule._eval_graph_array(
root::GraphNode{T},
cX::AbstractMatrix{T},
operators::OperatorEnum,
loopVectorization::Val{true}
) where {T}

# vmap is faster with small cX sizes
# vmapnt (non-temporal) is faster with larger cX sizes (too big so not worth caching?)

order = topological_sort(root)
for node in order
if node.degree == 0 && !node.constant
node.cache = view(cX, node.feature, :)
elseif node.degree == 1
if node.l.constant
node.constant = true
node.val = operators.unaops[node.op](node.l.val)
if !is_valid(node.val) return ResultOk(Vector{T}(undef, size(cX, 2)), false) end
else
node.constant = false
node.cache = vmapnt(operators.unaops[node.op], node.l.cache)
if !is_valid_array(node.cache) return ResultOk(node.cache, false) end
end
elseif node.degree == 2
if node.l.constant
if node.r.constant
node.constant = true
node.val = operators.binops[node.op](node.l.val, node.r.val)
if !is_valid(node.val) return ResultOk(Vector{T}(undef, size(cX, 2)), false) end
else
node.constant = false
node.cache = vmapnt(Base.Fix1(operators.binops[node.op], node.l.val), node.r.cache)
if !is_valid_array(node.cache) return ResultOk(node.cache, false) end
end
else
if node.r.constant
node.constant = false
node.cache = vmapnt(Base.Fix2(operators.binops[node.op], node.r.val), node.l.cache)
if !is_valid_array(node.cache) return ResultOk(node.cache, false) end
else
node.constant = false
node.cache = vmapnt(operators.binops[node.op], node.l.cache, node.r.cache)
if !is_valid_array(node.cache) return ResultOk(node.cache, false) end
end
end
end
end
if root.constant
return ResultOk(fill(root.val, size(cX, 2)), true)
else
return ResultOk(root.cache, true)
end
end

end
72 changes: 72 additions & 0 deletions ext/DynamicExpressionsVisualizeExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
module DynamicExpressionsVisualizeExt

using Plots, GraphRecipes, DynamicExpressions
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a way to use https://github.com/JuliaPlots/Plots.jl/tree/master/RecipesBase so we don't need to directly depend on Plots.jl? If using RecipesBase.jl it can usually interface with other plotting packages.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually is Plots.jl used here?

using DynamicExpressions: GraphNode, Node, topological_sort, AbstractOperatorEnum, get_op_name

function DynamicExpressions.visualize(
graph::Union{GraphNode,Node}, # types accepted by topological_sort
operators::AbstractOperatorEnum,
show = true
)
@info "Generating graph visualization"

order = reverse(topological_sort(graph))

# multigraph adjacency list
g = map(
node -> convert(Vector{Int64}, map(
cindex -> findfirst(x -> x === node.children[cindex][], order),
1:node.degree
)),
order
)

# node labels
n = map(x ->
if x.degree == 0
x.constant ? x.val : 'x' * string(x.feature)
elseif x.degree == 1
join(get_op_name(operators.unaops[x.op]))
elseif x.degree == 2
join(get_op_name(operators.binops[x.op]))
else
@warn "Can't label operator node with degree > 2"
end,
order
)

# edge labels (specifies parameter no.)
e = Dict{Tuple{Int64, Int64, Int64}, String}()
for (index, node) in enumerate(order)
edge_count = Dict{Int64, Int64}() # count number of edges to each child node
for cindex in 1:node.degree
order_cindex = findfirst(x -> x === node.children[cindex][], order)
get!(
e,
(
index, # source
order_cindex, # dest
get!(edge_count, order_cindex, pop!(edge_count, order_cindex, 0)+1) # edge no.
),
string(cindex)
)
end
end

# node colours
c = map(x -> x == 1 ? 2 : 1, eachindex(order))

return graphplot(
g,
names = n,
edgelabel = e,
nodecolor = c,
show = show,
nodeshape=:circle,
edge_label_box = false,
edgelabel_offset = 0.015,
nodesize=0.15
)
end

end
11 changes: 8 additions & 3 deletions src/DynamicExpressions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ using DispatchDoctor: @stable, @unstable
include("Random.jl")
include("Parse.jl")
include("ParametricExpression.jl")
include("Visualize.jl")
include("StructuredExpression.jl")
end

Expand All @@ -44,11 +45,14 @@ import .ValueInterfaceModule:
set_node!,
tree_mapreduce,
filter_map,
filter_map!
filter_map!,
topological_sort,
randomised_topological_sort
import .NodeModule:
constructorof,
with_type_parameters,
preserve_sharing,
max_degree,
leaf_copy,
branch_copy,
leaf_hash,
Expand All @@ -66,8 +70,7 @@ import .NodeModule:
count_scalar_constants,
get_scalar_constants,
set_scalar_constants!
@reexport import .StringsModule: string_tree, print_tree
import .StringsModule: get_op_name
@reexport import .StringsModule: string_tree, print_tree, get_op_name
@reexport import .OperatorEnumModule: AbstractOperatorEnum
@reexport import .OperatorEnumConstructionModule:
OperatorEnum, GenericOperatorEnum, @extend_operators, set_default_variable_names!
Expand All @@ -93,6 +96,7 @@ import .ExpressionModule:
@reexport import .ParseModule: @parse_expression, parse_expression
import .ParseModule: parse_leaf
@reexport import .ParametricExpressionModule: ParametricExpression, ParametricNode
@reexport import .VisualizeModule: visualize
@reexport import .StructuredExpressionModule: StructuredExpression

@stable default_mode = "disable" begin
Expand All @@ -104,6 +108,7 @@ end
import .InterfacesModule:
ExpressionInterface, NodeInterface, all_ei_methods_except, all_ni_methods_except


function __init__()
@require_extensions
end
Expand Down
Loading
Loading