Skip to content

Commit

Permalink
Merge pull request #114 from SymbolicML/allocs-functions
Browse files Browse the repository at this point in the history
Create preallocation utility functions for expressions
  • Loading branch information
MilesCranmer authored Dec 13, 2024
2 parents e7955d6 + 16c5ef0 commit 46eb0de
Show file tree
Hide file tree
Showing 10 changed files with 222 additions and 93 deletions.
3 changes: 2 additions & 1 deletion src/DynamicExpressions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ using DispatchDoctor: @stable, @unstable
include("OperatorEnum.jl")
include("Node.jl")
include("NodeUtils.jl")
include("NodePreallocation.jl")
include("Strings.jl")
include("Evaluate.jl")
include("EvaluateDerivative.jl")
Expand Down Expand Up @@ -41,11 +42,11 @@ import .ValueInterfaceModule:
GraphNode,
Node,
copy_node,
copy_node!,
set_node!,
tree_mapreduce,
filter_map,
filter_map!
import .NodePreallocationModule: allocate_container, copy_into!
import .NodeModule:
constructorof,
with_type_parameters,
Expand Down
20 changes: 20 additions & 0 deletions src/Expression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import ..NodeUtilsModule:
count_scalar_constants,
get_scalar_constants,
set_scalar_constants!
import ..NodePreallocationModule: copy_into!, allocate_container
import ..EvaluateModule: eval_tree_array, differentiable_eval_tree_array
import ..EvaluateDerivativeModule: eval_grad_tree_array
import ..EvaluationHelpersModule: _grad_evaluator
Expand Down Expand Up @@ -502,4 +503,23 @@ function (ex::AbstractExpression)(
return get_tree(ex)(X, get_operators(ex, operators); kws...)
end

# We don't require users to overload this, as it's not part of the required interface.
# Also, there's no way to generally do this from the required interface, so for backwards
# compatibility, we just return nothing.
# COV_EXCL_START
function copy_into!(::Nothing, src::AbstractExpression)
return copy(src)
end
function allocate_container(::AbstractExpression, ::Union{Nothing,Integer}=nothing)
return nothing
end
# COV_EXCL_STOP
function allocate_container(prototype::Expression, n::Union{Nothing,Integer}=nothing)
return (; tree=allocate_container(get_contents(prototype), n))
end
function copy_into!(dest::NamedTuple, src::Expression)
tree = copy_into!(dest.tree, get_contents(src))
return with_contents(src, tree)
end

end
32 changes: 22 additions & 10 deletions src/Interfaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,10 @@ using ..NodeModule:
default_allocator,
with_type_parameters,
leaf_copy,
leaf_copy!,
leaf_convert,
leaf_hash,
leaf_equal,
branch_copy,
branch_copy!,
branch_convert,
branch_hash,
branch_equal,
Expand All @@ -38,6 +36,8 @@ using ..NodeUtilsModule:
has_constants,
get_scalar_constants,
set_scalar_constants!
using ..NodePreallocationModule:
copy_into!, leaf_copy_into!, branch_copy_into!, allocate_container
using ..StringsModule: string_tree
using ..EvaluateModule: eval_tree_array
using ..EvaluateDerivativeModule: eval_grad_tree_array
Expand Down Expand Up @@ -96,6 +96,11 @@ function _check_with_metadata(ex::AbstractExpression)
end

## optional
function _check_copy_into!(ex::AbstractExpression)
container = allocate_container(ex)
prealloc_ex = copy_into!(container, ex)
return container !== nothing && prealloc_ex == ex && prealloc_ex !== ex
end
function _check_count_nodes(ex::AbstractExpression)
return count_nodes(ex) isa Int64
end
Expand Down Expand Up @@ -156,6 +161,7 @@ ei_components = (
with_metadata = "returns the expression with different metadata" => _check_with_metadata,
),
optional = (
copy_into! = "copies an expression into a preallocated container" => _check_copy_into!,
count_nodes = "counts the number of nodes in the expression tree" => _check_count_nodes,
count_constant_nodes = "counts the number of constant nodes in the expression tree" => _check_count_constant_nodes,
count_depth = "calculates the depth of the expression tree" => _check_count_depth,
Expand Down Expand Up @@ -260,14 +266,19 @@ function _check_tree_mapreduce(tree::AbstractExpressionNode)
end

## optional
function _check_copy_into!(tree::AbstractExpressionNode)
container = allocate_container(tree)
prealloc_tree = copy_into!(container, tree)
return container !== nothing && prealloc_tree == tree && prealloc_tree !== container
end
function _check_leaf_copy(tree::AbstractExpressionNode)
tree.degree != 0 && return true
return leaf_copy(tree) isa typeof(tree)
end
function _check_leaf_copy!(tree::AbstractExpressionNode{T}) where {T}
function _check_leaf_copy_into!(tree::AbstractExpressionNode{T}) where {T}
tree.degree != 0 && return true
new_leaf = constructorof(typeof(tree))(; val=zero(T))
ret = leaf_copy!(new_leaf, tree)
ret = leaf_copy_into!(new_leaf, tree)
return new_leaf == tree && ret === new_leaf
end
function _check_leaf_convert(tree::AbstractExpressionNode)
Expand All @@ -292,16 +303,16 @@ function _check_branch_copy(tree::AbstractExpressionNode)
return branch_copy(tree, tree.l, tree.r) isa typeof(tree)
end
end
function _check_branch_copy!(tree::AbstractExpressionNode{T}) where {T}
function _check_branch_copy_into!(tree::AbstractExpressionNode{T}) where {T}
if tree.degree == 0
return true
end
new_branch = constructorof(typeof(tree))(; val=zero(T))
if tree.degree == 1
ret = branch_copy!(new_branch, tree, copy(tree.l))
ret = branch_copy_into!(new_branch, tree, copy(tree.l))
return new_branch == tree && ret === new_branch
else
ret = branch_copy!(new_branch, tree, copy(tree.l), copy(tree.r))
ret = branch_copy_into!(new_branch, tree, copy(tree.l), copy(tree.r))
return new_branch == tree && ret === new_branch
end
end
Expand Down Expand Up @@ -372,13 +383,14 @@ ni_components = (
tree_mapreduce = "applies a function across the tree" => _check_tree_mapreduce
),
optional = (
copy_into! = "copies a node into a preallocated container" => _check_copy_into!,
leaf_copy = "copies a leaf node" => _check_leaf_copy,
leaf_copy! = "copies a leaf node in-place" => _check_leaf_copy!,
leaf_copy_into! = "copies a leaf node in-place" => _check_leaf_copy_into!,
leaf_convert = "converts a leaf node" => _check_leaf_convert,
leaf_hash = "computes the hash of a leaf node" => _check_leaf_hash,
leaf_equal = "checks equality of two leaf nodes" => _check_leaf_equal,
branch_copy = "copies a branch node" => _check_branch_copy,
branch_copy! = "copies a branch node in-place" => _check_branch_copy!,
branch_copy_into! = "copies a branch node in-place" => _check_branch_copy_into!,
branch_convert = "converts a branch node" => _check_branch_convert,
branch_hash = "computes the hash of a branch node" => _check_branch_hash,
branch_equal = "checks equality of two branch nodes" => _check_branch_equal,
Expand Down Expand Up @@ -419,7 +431,7 @@ ni_description = (
[Arguments()]
)
@implements(
NodeInterface{all_ni_methods_except((:leaf_copy!, :branch_copy!))},
NodeInterface{all_ni_methods_except(())},
GraphNode,
[Arguments()]
)
Expand Down
11 changes: 0 additions & 11 deletions src/Node.jl
Original file line number Diff line number Diff line change
Expand Up @@ -321,23 +321,12 @@ function Base.promote_rule(::Type{GraphNode{T1}}, ::Type{GraphNode{T2}}) where {
return GraphNode{promote_type(T1, T2)}
end

# TODO: Verify using this helps with garbage collection
create_dummy_node(::Type{N}) where {N<:AbstractExpressionNode} = N()

"""
set_node!(tree::AbstractExpressionNode{T}, new_tree::AbstractExpressionNode{T}) where {T}
Set every field of `tree` equal to the corresponding field of `new_tree`.
"""
function set_node!(tree::AbstractExpressionNode, new_tree::AbstractExpressionNode)
# First, ensure we free some memory:
if new_tree.degree < 2 && tree.degree == 2
tree.r = create_dummy_node(typeof(tree))
end
if new_tree.degree < 1 && tree.degree >= 1
tree.l = create_dummy_node(typeof(tree))
end

tree.degree = new_tree.degree
if new_tree.degree == 0
tree.constant = new_tree.constant
Expand Down
69 changes: 69 additions & 0 deletions src/NodePreallocation.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
module NodePreallocationModule

using ..NodeModule:
AbstractExpressionNode,
with_type_parameters,
tree_mapreduce,
leaf_copy,
branch_copy,
set_node!

"""
allocate_container(prototype::AbstractExpressionNode, n=nothing)
Preallocate an array of `n` empty nodes matching the type of `prototype`.
If `n` is not provided, it will be computed from `length(prototype)`.
A given return value of this will be passed to `copy_into!` as the first argument,
so it should be compatible.
"""
function allocate_container(
prototype::N, n::Union{Nothing,Integer}=nothing
) where {T,N<:AbstractExpressionNode{T}}
num_nodes = @something(n, length(prototype))
return N[with_type_parameters(N, T)() for _ in 1:num_nodes]
end

"""
copy_into!(dest::AbstractArray{N}, src::N) where {N<:AbstractExpressionNode}
Copy a node, recursively copying all children nodes, in-place to a preallocated container.
This should result in no extra allocations.
"""
function copy_into!(
dest::AbstractArray{N}, src::N; ref::Union{Nothing,Base.RefValue{<:Integer}}=nothing
) where {N<:AbstractExpressionNode}
_ref = if ref === nothing
Ref(0)
else
ref.x = 0
ref
end
return tree_mapreduce(
leaf -> leaf_copy_into!(@inbounds(dest[_ref.x += 1]), leaf),
identity,
((p, c::Vararg{Any,M}) where {M}) ->
branch_copy_into!(@inbounds(dest[_ref.x += 1]), p, c...),
src,
N,
)
end
# COV_EXCL_START
function leaf_copy_into!(dest::N, src::N) where {N<:AbstractExpressionNode}
set_node!(dest, src)
return dest
end
# COV_EXCL_STOP
function branch_copy_into!(
dest::N, src::N, children::Vararg{N,M}
) where {N<:AbstractExpressionNode,M}
dest.degree = M
dest.op = src.op
dest.l = children[1]
if M == 2
dest.r = children[2]
end
return dest
end

end
62 changes: 46 additions & 16 deletions src/ParametricExpression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,19 @@ using ChainRulesCore: ChainRulesCore as CRC, NoTangent, @thunk

using ..OperatorEnumModule: AbstractOperatorEnum, OperatorEnum
using ..NodeModule: AbstractExpressionNode, Node, tree_mapreduce
using ..ExpressionModule: AbstractExpression, Metadata
using ..ExpressionModule: AbstractExpression, Metadata, with_contents, with_metadata
using ..ChainRulesModule: NodeTangent

import ..NodeModule:
constructorof,
with_type_parameters,
preserve_sharing,
leaf_copy,
leaf_copy!,
leaf_convert,
leaf_hash,
leaf_equal,
branch_copy!
set_node!
import ..NodePreallocationModule: copy_into!, allocate_container
import ..NodeUtilsModule:
count_constant_nodes,
index_constant_nodes,
Expand Down Expand Up @@ -124,21 +124,29 @@ function leaf_copy(t::ParametricNode{T}) where {T}
return n
end
end
function leaf_copy!(dest::N, src::N) where {T,N<:ParametricNode{T}}
dest.degree = 0
if src.constant
dest.constant = true
dest.val = src.val
elseif !src.is_parameter
dest.constant = false
dest.is_parameter = false
dest.feature = src.feature
function set_node!(tree::ParametricNode, new_tree::ParametricNode)
tree.degree = new_tree.degree
if new_tree.degree == 0
if new_tree.constant
tree.constant = true
tree.val = new_tree.val
elseif !new_tree.is_parameter
tree.constant = false
tree.is_parameter = false
tree.feature = new_tree.feature
else
tree.constant = false
tree.is_parameter = true
tree.parameter = new_tree.parameter
end
else
dest.constant = false
dest.is_parameter = true
dest.parameter = src.parameter
tree.op = new_tree.op
tree.l = new_tree.l
if new_tree.degree == 2
tree.r = new_tree.r
end
end
return dest
return nothing
end
function leaf_convert(::Type{N}, t::ParametricNode) where {T,N<:ParametricNode{T}}
if t.constant
Expand Down Expand Up @@ -444,6 +452,28 @@ end
return node_type(; val=ex)
end
end
function allocate_container(
prototype::ParametricExpression, n::Union{Nothing,Integer}=nothing
)
return (;
tree=allocate_container(get_contents(prototype), n),
parameters=similar(get_metadata(prototype).parameters),
)
end
function copy_into!(dest::NamedTuple, src::ParametricExpression)
new_tree = copy_into!(dest.tree, get_contents(src))
metadata = get_metadata(src)
new_parameters = dest.parameters
new_parameters .= metadata.parameters
new_metadata = Metadata((;
operators=metadata.operators,
variable_names=metadata.variable_names,
parameters=new_parameters,
parameter_names=metadata.parameter_names,
))
# TODO: Better interface for this^
return with_metadata(with_contents(src, new_tree), new_metadata)
end
###############################################################################

end
14 changes: 14 additions & 0 deletions src/StructuredExpression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@ using ..ExpressionModule: AbstractExpression, Metadata, node_type
using ..ChainRulesModule: NodeTangent

import ..NodeModule: constructorof
import ..NodePreallocationModule: copy_into!, allocate_container
import ..ExpressionModule:
get_contents,
get_metadata,
get_tree,
get_operators,
get_variable_names,
with_contents,
Metadata,
_copy,
_data,
Expand Down Expand Up @@ -164,4 +166,16 @@ function set_scalar_constants!(e::AbstractStructuredExpression, constants, refs)
return e
end

function allocate_container(
e::AbstractStructuredExpression, n::Union{Nothing,Integer}=nothing
)
ts = get_contents(e)
return (; trees=NamedTuple{keys(ts)}(map(t -> allocate_container(t, n), values(ts))))
end
function copy_into!(dest::NamedTuple, src::AbstractStructuredExpression)
ts = get_contents(src)
new_contents = NamedTuple{keys(ts)}(map(copy_into!, values(dest.trees), values(ts)))
return with_contents(src, new_contents)
end

end
Loading

0 comments on commit 46eb0de

Please sign in to comment.