Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Graph evaluator caching and expression visualiser #98

Closed
wants to merge 19 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
refactor: no more need for memoize_on
  • Loading branch information
MilesCranmer committed Jul 8, 2024
commit 8707d24726e42d1f4d0d36339062d95070b4bb9d
97 changes: 0 additions & 97 deletions src/Utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,103 +12,6 @@ macro return_on_false2(flag, retval, retval2)
)
end

"""
@memoize_on tree [postprocess] function my_function_on_tree(tree::AbstractExpressionNode)
...
end

This macro takes a function definition and creates a second version of the
function with an additional `id_map` argument. When passed this argument (an
IdDict()), it will use the `id_map` to avoid recomputing the same value
for the same node in a tree. Use this to automatically create functions that
work with trees that have shared child nodes.

Can optionally take a `postprocess` function, which will be applied to the
result of the function before returning it, taking the result as the
first argument and a boolean for whether the result was memoized as the
second argument. This is useful for functions that need to count the number
of unique nodes in a tree, for example.
"""
macro memoize_on(tree, args...)
    # `args` holds either `(def,)` or `(postprocess, def)`; together with `tree`
    # that makes 2 or 3 macro arguments in total.
    nargs = length(args)
    if nargs != 1 && nargs != 2
        error("Expected 2 or 3 arguments to @memoize_on")
    end

    # With no explicit postprocess function, fall back to one that returns the
    # result untouched and ignores the "was memoized" flag.
    postprocess, def = if nargs == 1
        (:((r, _) -> r), args[1])
    else
        (args[1], args[2])
    end

    # Build the second method, which takes the extra `id_map` argument.
    idmap_def = _memoize_on(tree, postprocess, def)

    # Emit both definitions, escaped so they are defined in the caller's scope.
    return quote
        $(esc(def)) # The normal function
        $(esc(idmap_def)) # The function with an id_map argument
    end
end
# Build the memoized variant of the function definition `def`.
#
# - `tree`: symbol naming the argument whose `objectid` is used as the cache key.
# - `postprocess`: expression for a two-argument callable applied to
#   `(result, is_memoized)` before returning.
# - `def`: the function definition expression to transform.
#
# Returns a new definition expression with a trailing `id_map::AbstractDict`
# positional argument.
function _memoize_on(tree::Symbol, postprocess, def)
    sdef = splitdef(def)
    f_name = sdef[:name]

    # The memoized method receives the cache as an extra trailing argument.
    push!(sdef[:args], :(id_map::AbstractDict))

    # Forward `id_map` to every call of the same function inside the body,
    # so that recursion keeps using the cache.
    forwarded_body = postwalk(sdef[:body]) do ex
        is_self_call = @capture(ex, f_(args__)) && f == f_name
        return is_self_call ? Expr(:call, f, args..., :id_map) : ex
    end

    # Wrap the body so the result is looked up in, or stored into, `id_map`
    # under the `objectid` of `tree`. The original body lives in a local
    # zero-argument function so its own `return` statements stay contained.
    @gensym key is_memoized result body
    sdef[:body] = quote
        $key = objectid($tree)
        $is_memoized = haskey(id_map, $key)
        function $body()
            return $forwarded_body
        end
        $result = if $is_memoized
            @inbounds(id_map[$key])
        else
            id_map[$key] = $body()
        end
        return $postprocess($result, $is_memoized)
    end

    return combinedef(sdef)
end

"""
@with_memoize(call, id_map)

This macro appends the `id_map` as the final argument
of the call, to be consistent with the `@memoize_on` macro.

```
@with_memoize(_copy_node(tree), IdDict{Any,Any}())
```

is converted to

```
_copy_node(tree, IdDict{Any,Any}())
```

"""
macro with_memoize(def, id_map)
    # Rewrite `f(args...)` into `f(args..., id_map)`; the escaped result is
    # evaluated in the caller's scope.
    call_with_map = _add_idmap_to_call(def, id_map)
    return quote
        $(esc(call_with_map))
    end
end

# Return a copy of the call expression `def` with `id_map` appended as the
# last argument. `def` must be a `:call` expression.
function _add_idmap_to_call(def::Expr, id_map::Union{Symbol,Expr})
    @assert def.head == :call
    callee = def.args[1]
    existing_args = def.args[2:end]
    return Expr(:call, callee, existing_args..., id_map)
end

@inline function fill_similar(value::T, array, args...) where {T}
out_array = similar(array, args...)
fill!(out_array, value)
Expand Down
83 changes: 60 additions & 23 deletions src/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import Base:

using DispatchDoctor: @unstable
using Compat: @inline, Returns
using ..UtilsModule: @memoize_on, @with_memoize, Undefined
using ..UtilsModule: Undefined

"""
tree_mapreduce(
Expand Down Expand Up @@ -89,46 +89,83 @@ function tree_mapreduce(
f_leaf::F1,
f_branch::F2,
op::G,
tree::AbstractNode,
tree::AbstractNode{D},
result_type::Type{RT}=Undefined;
f_on_shared::H=(result, is_shared) -> result,
break_sharing::Val=Val(false),
) where {F1<:Function,F2<:Function,G<:Function,H<:Function,RT}

# Trick taken from here:
# https://discourse.julialang.org/t/recursive-inner-functions-a-thousand-times-slower/85604/5
# to speed up recursive closure
@memoize_on t f_on_shared function inner(inner, t)
if t.degree == 0
return @inline(f_leaf(t))
elseif t.degree == 1
return @inline(op(@inline(f_branch(t)), inner(inner, t.l)))
else
return @inline(op(@inline(f_branch(t)), inner(inner, t.l), inner(inner, t.r)))
end
end

sharing = preserve_sharing(typeof(tree)) && break_sharing === Val(false)
break_sharing::Val{BS}=Val(false),
) where {F1<:Function,F2<:Function,G<:Function,D,H<:Function,RT,BS}
sharing = preserve_sharing(typeof(tree)) && !break_sharing

RT == Undefined &&
sharing &&
throw(ArgumentError("Need to specify `result_type` if nodes are shared.."))

if sharing && RT != Undefined
d = allocate_id_map(tree, RT)
return @with_memoize inner(inner, tree) d
id_map = allocate_id_map(tree, RT)
reducer = TreeMapreducer(Val(D), id_map, f_leaf, f_branch, op, f_on_shared)
return reducer(tree)
else
reducer = TreeMapreducer(Val(D), nothing, f_leaf, f_branch, op, f_on_shared)
return reducer(tree)
end
end

"""
    TreeMapreducer(max_degree, id_map, f_leaf, f_branch, op, f_on_shared)

Callable object bundling everything needed for a recursive map-reduce over a tree.

# Fields
- `max_degree::Val{D}`: carries the degree bound `D` in the type domain.
- `id_map`: result cache; `tree_mapreduce` passes `nothing` when results
  are not cached.
- `f_leaf`: function applied to nodes of degree 0.
- `f_branch`: function applied to nodes of nonzero degree.
- `op`: combines the `f_branch` result with the results from the children.
- `f_on_shared`: called as `f_on_shared(result, is_cached)` on each result
  when caching is active.
"""
struct TreeMapreducer{
    D,
    ID,
    F1<:Function,
    F2<:Function,
    G<:Function,
    H<:Function,
}
    max_degree::Val{D}
    id_map::ID
    f_leaf::F1
    f_branch::F2
    op::G
    f_on_shared::H
end

# Call overload that performs the map-reduce on `tree`.
#
# The body is generated per `(MAX_DEGREE, ID)` combination:
# - The degree dispatch is unrolled via `Base.Cartesian.@nif` into
#   `MAX_DEGREE + 1` branches (degrees `0` through `MAX_DEGREE`). Degree 0
#   calls `f_leaf(tree)`; degree `k > 0` calls
#   `op(f_branch(tree), <results of the first k children>...)`, recursing
#   through `mapreducer(tree.children[i][])`.
# - When `ID <: Nothing` no caching code is emitted at all.
# - Otherwise each result is cached in `id_map` under `objectid(tree)` and
#   passed through `f_on_shared(result, is_cached)` before being returned.
@generated function (mapreducer::TreeMapreducer{MAX_DEGREE,ID})(
    tree::AbstractNode
) where {MAX_DEGREE,ID}
    # `d_p_one` ("degree plus one") is the 1-based branch index that `@nif`
    # substitutes as a literal, so `d_p_one - 1` is the degree that branch handles.
    base_expr = quote
        d = tree.degree
        Base.Cartesian.@nif(
            $(MAX_DEGREE + 1),
            d_p_one -> (d_p_one - 1) == d,
            d_p_one -> if d_p_one == 1
                mapreducer.f_leaf(tree)
            else
                mapreducer.op(
                    mapreducer.f_branch(tree),
                    Base.Cartesian.@ntuple(
                        d_p_one - 1, i -> mapreducer(tree.children[i][])
                    )...,
                )
            end
        )
    end
    if ID <: Nothing
        # No sharing of nodes (is a tree, not a graph)
        return base_expr
    else
        # Otherwise, we need to cache results in `id_map`
        # according to `objectid` of the node.
        # (A leftover `return inner(inner, tree)` used to sit here; `inner` does
        # not exist in this generator and it returned before the caching code.)
        return quote
            key = objectid(tree)
            is_cached = haskey(mapreducer.id_map, key)
            if is_cached
                return mapreducer.f_on_shared(@inbounds(mapreducer.id_map[key]), true)
            else
                res = $base_expr
                mapreducer.id_map[key] = res
                return mapreducer.f_on_shared(res, false)
            end
        end
    end
end

# Create an empty `Dict{UInt,RT}` for caching per-node results, with capacity
# hinted from the size of `tree`.
function allocate_id_map(tree::AbstractNode, ::Type{RT}) where {RT}
    # Reserve for the maximum possible number of entries: counting with
    # `break_sharing=Val(true)` includes duplicates, which is fast.
    capacity = length(tree; break_sharing=Val(true))
    return sizehint!(Dict{UInt,RT}(), capacity)
end

# TODO: Raise Julia issue for this.
# Surprisingly Dict{UInt,RT} is faster than IdDict{Node{T},RT} here!
# I think it's because `setindex!` is declared with `@nospecialize` in IdDict.
Expand Down
70 changes: 0 additions & 70 deletions test/test_graphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,76 +120,6 @@ end

@test expr_eql(ex, true_ex)
end

# Checks the code generated by `@memoize_on` against a hand-written expansion.
@testset "@memoize_on" begin
# Expand the macro (without evaluating it) on a recursive node-copying function,
# passing `tree` as the memoization key and an identity-like postprocess function.
ex = @macroexpand DynamicExpressions.UtilsModule.@memoize_on tree ((x, _) -> x) function _copy_node(
tree::Node{T}
)::Node{T} where {T}
if tree.degree == 0
if tree.constant
Node(; val=copy(tree.val))
else
Node(T; feature=copy(tree.feature))
end
elseif tree.degree == 1
Node(copy(tree.op), _copy_node(tree.l))
else
Node(copy(tree.op), _copy_node(tree.l), _copy_node(tree.r))
end
end
# Expected expansion: the untouched original method, followed by a second
# method that takes an extra `id_map::AbstractDict` argument.
true_ex = quote
function _copy_node(tree::Node{T})::Node{T} where {T}
if tree.degree == 0
if tree.constant
Node(; val=copy(tree.val))
else
Node(T; feature=copy(tree.feature))
end
elseif tree.degree == 1
Node(copy(tree.op), _copy_node(tree.l))
else
Node(copy(tree.op), _copy_node(tree.l), _copy_node(tree.r))
end
end
function _copy_node(tree::Node{T}, id_map::AbstractDict;)::Node{T} where {T}
# Cache key is the `objectid` of the node being processed.
key = objectid(tree)
is_memoized = haskey(id_map, key)
# Original body wrapped in a local function, with `id_map` forwarded
# to every recursive `_copy_node` call.
function body()
return begin
if tree.degree == 0
if tree.constant
Node(; val=copy(tree.val))
else
Node(T; feature=copy(tree.feature))
end
elseif tree.degree == 1
Node(copy(tree.op), _copy_node(tree.l, id_map))
else
Node(
copy(tree.op),
_copy_node(tree.l, id_map),
_copy_node(tree.r, id_map),
)
end
end
end
# The `begin ... end` below is the expanded form of `@inbounds(id_map[key])`.
result = if is_memoized
begin
$(Expr(:inbounds, true))
local val = id_map[key]
$(Expr(:inbounds, :pop))
val
end
else
id_map[key] = body()
end
# The postprocess function receives the result and the memoization flag.
return (((x, _) -> begin
x
end)(result, is_memoized))
end
end
# NOTE(review): `expr_eql` is defined outside this view; presumably it compares
# the two expressions while ignoring line numbers and gensym names — confirm.
@test expr_eql(ex, true_ex)
end
end

@testset "Operations on graphs" begin
Expand Down