Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Graph-like expressions #56

Merged
merged 58 commits into from
Dec 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
58 commits
Select commit Hold shift + click to select a range
0251178
Clean up sharing tests
MilesCranmer Nov 25, 2023
ef9b690
Fix compat for Aqua
MilesCranmer Nov 25, 2023
d91f0cb
Expand `tree_mapreduce` to handle custom behaviour at shared nodes
MilesCranmer Nov 25, 2023
ea9a2fa
Add more `preserve_sharing` implementations
MilesCranmer Nov 25, 2023
b840190
Refactor string_tree
MilesCranmer Nov 25, 2023
b52cc6c
Strip brackets in unittests
MilesCranmer Nov 25, 2023
737c69e
Speed up `string_tree` using `Vector{Char}`
MilesCranmer Nov 26, 2023
22f9a45
Clean up formatting
MilesCranmer Nov 26, 2023
fb99bb4
Simplify NodeIndex implementation
MilesCranmer Nov 26, 2023
653dfc9
Try to free dangling nodes
MilesCranmer Nov 26, 2023
85349f8
Get `preserve_sharing` working with `simplify_tree!`
MilesCranmer Nov 26, 2023
c9ae9b5
Add `preserve_sharing` to `print_tree`
MilesCranmer Nov 27, 2023
7daa70c
Global setting for printing with sharing
MilesCranmer Nov 27, 2023
4e2cd26
Ensure variable names copied when set as global
MilesCranmer Nov 27, 2023
d71a123
Rename to `test_graphs.jl`
MilesCranmer Nov 27, 2023
700b6e0
Implement `GraphNode` type for shared edges
MilesCranmer Nov 27, 2023
1cad0ab
Fix some errors in `GraphNode` implementation
MilesCranmer Nov 27, 2023
418c40b
Fix for custom node types
MilesCranmer Nov 27, 2023
471edac
Bump version for milestone
MilesCranmer Nov 27, 2023
ba221a3
Get benchmark compatible with old shared nodes
MilesCranmer Nov 27, 2023
bb3b6d6
Expand tested functions in benchmark
MilesCranmer Nov 27, 2023
c1baad1
Pre-allocate storage for IdDict
MilesCranmer Nov 28, 2023
da644b9
Fix benchmark for earlier versions
MilesCranmer Nov 28, 2023
b1431e6
Fix macro test
MilesCranmer Nov 28, 2023
2bd986f
Fix benchmark for earlier versions
MilesCranmer Nov 28, 2023
6a116a9
Turn off sharing in count_depth calculation
MilesCranmer Dec 3, 2023
687a2b0
More tests for GraphNode
MilesCranmer Dec 3, 2023
1400752
Clean up node index
MilesCranmer Dec 3, 2023
d67334b
Reduce NodeIndex
MilesCranmer Dec 3, 2023
b92b3f1
Create `with_type_parameters` function
MilesCranmer Dec 4, 2023
191fa29
Add test of shared hashing
MilesCranmer Dec 17, 2023
a7e162b
Add break_sharing to more Base functions"
MilesCranmer Dec 17, 2023
d5b329e
Test additional base methods with break_sharing
MilesCranmer Dec 17, 2023
8cb9c8a
Formatting
MilesCranmer Dec 17, 2023
83f0a61
Force RecursiveArrayTools to version 2
MilesCranmer Dec 17, 2023
ea3d5bf
Test conversion as expected
MilesCranmer Dec 17, 2023
1487131
Test filter_map
MilesCranmer Dec 17, 2023
7860ae6
Formatting
MilesCranmer Dec 18, 2023
54f1245
Remove plots from benchmarks
MilesCranmer Dec 18, 2023
65e7594
Merge branch 'master' into MilesCranmer/issue14
MilesCranmer Dec 18, 2023
79b58ec
Revert "Force RecursiveArrayTools to version 2"
MilesCranmer Dec 18, 2023
f4ffa63
Align benchmarks across revisions
MilesCranmer Dec 18, 2023
40d3c1d
Test equality across types
MilesCranmer Dec 18, 2023
23ce949
Remove unused functions
MilesCranmer Dec 18, 2023
45b302f
Remove unused function
MilesCranmer Dec 18, 2023
9caa14b
Ensure `filter_map!` is specialized to functions
MilesCranmer Dec 18, 2023
07f434b
Make macro more robust against symbol overlap
MilesCranmer Dec 18, 2023
4180759
Type stability in equation utilities
MilesCranmer Dec 18, 2023
a2b4dcf
Add `break_sharing` to `copy`
MilesCranmer Dec 18, 2023
05a695c
Move `count_nodes` back to EquationUtils.jl
MilesCranmer Dec 18, 2023
3677a02
Make `set_constants!` lighter weight
MilesCranmer Dec 18, 2023
5bc2a2a
Move `count_nodes` back to `base.jl`
MilesCranmer Dec 19, 2023
f1af402
Overload common operators like +, -, *, so errors are more informative
MilesCranmer Dec 19, 2023
542413d
Update docs with GraphNode
MilesCranmer Dec 19, 2023
1103d14
Refactor constructors
MilesCranmer Dec 19, 2023
627091f
Fix docs rendering
MilesCranmer Dec 19, 2023
76655d3
Clean up type parameters in constructors
MilesCranmer Dec 19, 2023
f8f9678
Expand testing
MilesCranmer Dec 19, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DynamicExpressions"
uuid = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
authors = ["MilesCranmer <[email protected]>"]
version = "0.13.1"
version = "0.14.0"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Expand All @@ -24,6 +24,7 @@ DynamicExpressionsSymbolicUtilsExt = "SymbolicUtils"
DynamicExpressionsZygoteExt = "Zygote"

[compat]
Aqua = "0.7"
Compat = "3.37, 4"
LoopVectorization = "0.12"
MacroTools = "0.4, 0.5"
Expand Down
47 changes: 37 additions & 10 deletions benchmark/benchmarks.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
using DynamicExpressions, BenchmarkTools, Random
using DynamicExpressions.EquationUtilsModule: is_constant
using Zygote
if PACKAGE_VERSION < v"0.14.0"
@eval using DynamicExpressions: Node as GraphNode
else
@eval using DynamicExpressions: GraphNode
end

include("benchmark_utils.jl")

Expand Down Expand Up @@ -66,13 +71,15 @@ end

# These macros make the benchmarks work on older versions:
#! format: off
@generated function _convert(::Type{N}, t; preserve_sharing) where {N<:Node}
@generated function _convert(::Type{N}, t; preserve_sharing) where {N}
PACKAGE_VERSION < v"0.7.0" && return :(convert(N, t))
return :(convert(N, t; preserve_sharing=preserve_sharing))
PACKAGE_VERSION < v"0.14.0" && return :(convert(N, t; preserve_sharing=preserve_sharing))
return :(convert(N, t)) # Assume type used to infer sharing
end
@generated function _copy_node(t; preserve_sharing)
PACKAGE_VERSION < v"0.7.0" && return :(copy_node(t; preserve_topology=preserve_sharing))
return :(copy_node(t; preserve_sharing=preserve_sharing))
PACKAGE_VERSION < v"0.14.0" && return :(copy_node(t; preserve_sharing=preserve_sharing))
return :(copy_node(t)) # Assume type used to infer sharing
end
@generated function get_set_constants!(tree)
!(@isdefined set_constants!) && return :(set_constants(tree, get_constants(tree)))
Expand Down Expand Up @@ -101,13 +108,36 @@ function benchmark_utilities()
:index_constants,
:string_tree,
)
has_both_modes = [:copy, :convert]
if PACKAGE_VERSION >= v"0.14.0"
append!(
has_both_modes,
[
:simplify_tree,
:count_nodes,
:count_constants,
:get_set_constants!,
:index_constants,
:string_tree,
],
)
end

operators = OperatorEnum(; binary_operators=[+, -, /, *], unary_operators=[cos, exp])
for func_k in all_funcs
suite[func_k] = let s = BenchmarkGroup()
for k in (:break_sharing, :preserve_sharing)
has_both_modes = func_k in (:copy, :convert)
k == :preserve_sharing && !has_both_modes && continue
for k in (
if func_k in has_both_modes
[:break_sharing, :preserve_sharing]
else
[:break_sharing]
end
)
preprocess = if k == :preserve_sharing && PACKAGE_VERSION >= v"0.14.0"
tree -> GraphNode(tree)
else
identity
end

f = if func_k == :copy
tree -> _copy_node(tree; preserve_sharing=(k == :preserve_sharing))
Expand All @@ -132,12 +162,9 @@ function benchmark_utilities()
setup=(
ntrees=100;
n=20;
trees=[gen_random_tree_fixed_size(n, $operators, 5, Float32) for _ in 1:ntrees]
trees=[$preprocess(gen_random_tree_fixed_size(n, $operators, 5, Float32)) for _ in 1:ntrees]
)
)
if !has_both_modes
s = s[k]
end
#! format: on
end
s
Expand Down
85 changes: 70 additions & 15 deletions docs/src/types.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,7 @@ Equations are specified as binary trees with the `Node` type, defined
as follows:

```@docs
Node{T}
```

There are a variety of constructors for `Node` objects, including:

```@docs
Node(::Type{T}; val=nothing, feature::Integer=nothing) where {T}
Node(op::Integer, l::Node)
Node(op::Integer, l::Node, r::Node)
Node(var_string::String)
Node
```

When you create an `Options` object, the operators
Expand All @@ -69,23 +60,87 @@ When using these node constructors, types will automatically be promoted.
You can convert the type of a node using `convert`:

```@docs
convert(::Type{Node{T1}}, tree::Node{T2}) where {T1, T2}
convert(::Type{AbstractExpressionNode{T1}}, tree::AbstractExpressionNode{T2}) where {T1, T2}
```

You can set a `tree` (in-place) with `set_node!`:

```@docs
set_node!(tree::Node{T}, new_tree::Node{T}) where {T}
set_node!
```

You can create a copy of a node with `copy_node`:

```@docs
copy_node(tree::Node)
copy_node
```

## Graph-Like Equations

You can describe an equation as a *graph* rather than a tree
by using the `GraphNode` type:

```@docs
GraphNode{T}
```

This makes it so you can have multiple parents for a given node,
and share parts of an expression. For example:

```julia
julia> operators = OperatorEnum(;
binary_operators=[+, -, *], unary_operators=[cos, sin, exp]
);

julia> x1, x2 = GraphNode(feature=1), GraphNode(feature=2)
(x1, x2)

julia> y = sin(x1) + 1.5
sin(x1) + 1.5

julia> z = exp(y) + y
exp(sin(x1) + 1.5) + {(sin(x1) + 1.5)}
```

Here, the curly braces `{}` indicate that the node
is shared by another (or more) parent node.

This means that we only need to change it once
to have changes propagate across the expression:

```julia
julia> y.r.val *= 0.9
1.35

julia> z
exp(sin(x1) + 1.35) + {(sin(x1) + 1.35)}
```

This also means there are fewer nodes to describe an expression:

```julia
julia> length(z)
6

julia> length(convert(Node, z))
10
```

where we have converted the `GraphNode` to a `Node` type,
which breaks shared connections into separate nodes.

## Abstract Types

Both the `Node` and `GraphNode` types are subtypes of the abstract type:

```@docs
AbstractExpressionNode{T}
```

There is also an abstract type `AbstractNode` which is a supertype of `Node`:
which can be used to create additional expression-like types.
The supertype of this abstract type is the `AbstractNode` type,
which is more generic but does not have all of the same methods:

```@docs
AbstractNode
AbstractNode{T}
```
57 changes: 33 additions & 24 deletions ext/DynamicExpressionsSymbolicUtilsExt.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
module DynamicExpressionsSymbolicUtilsExt

using SymbolicUtils
import DynamicExpressions.EquationModule: Node, DEFAULT_NODE_TYPE
import DynamicExpressions.EquationModule:
AbstractExpressionNode, Node, constructorof, DEFAULT_NODE_TYPE
import DynamicExpressions.OperatorEnumModule: AbstractOperatorEnum
import DynamicExpressions.UtilsModule: isgood, isbad, @return_on_false, deprecate_varmap
import DynamicExpressions.ExtensionInterfaceModule: node_to_symbolic, symbolic_to_node
Expand All @@ -19,14 +20,17 @@ end
subs_bad(x) = isgood(x) ? x : Inf

function parse_tree_to_eqs(
tree::Node{T}, operators::AbstractOperatorEnum, index_functions::Bool=false
tree::AbstractExpressionNode{T},
operators::AbstractOperatorEnum,
index_functions::Bool=false,
) where {T}
if tree.degree == 0
# Return constant if needed
tree.constant && return subs_bad(tree.val::T)
return SymbolicUtils.Sym{LiteralReal}(Symbol("x$(tree.feature)"))
end
# Collect the next children
# TODO: Type instability!
children = tree.degree == 2 ? (tree.l, tree.r) : (tree.l,)
# Get the operation
op = tree.degree == 2 ? operators.binops[tree.op] : operators.unaops[tree.op]
Expand Down Expand Up @@ -66,11 +70,12 @@ convert_to_function(x, operators::AbstractOperatorEnum) = x
function split_eq(
op,
args,
operators::AbstractOperatorEnum;
operators::AbstractOperatorEnum,
::Type{N}=Node;
variable_names::Union{Array{String,1},Nothing}=nothing,
# Deprecated:
varMap=nothing,
)
) where {N<:AbstractExpressionNode}
variable_names = deprecate_varmap(variable_names, varMap, :split_eq)
!(op ∈ (sum, prod, +, *)) && throw(error("Unsupported operation $op in expression!"))
if Symbol(op) == Symbol(sum)
Expand All @@ -80,10 +85,10 @@ function split_eq(
else
ind = findoperation(op, operators.binops)
end
return Node(
return constructorof(N)(
ind,
convert(Node, args[1], operators; variable_names=variable_names),
convert(Node, op(args[2:end]...), operators; variable_names=variable_names),
convert(N, args[1], operators; variable_names=variable_names),
convert(N, op(args[2:end]...), operators; variable_names=variable_names),
)
end

Expand All @@ -96,7 +101,7 @@ end

function Base.convert(
::typeof(SymbolicUtils.Symbolic),
tree::Node,
tree::AbstractExpressionNode,
operators::AbstractOperatorEnum;
variable_names::Union{Array{String,1},Nothing}=nothing,
index_functions::Bool=false,
Expand All @@ -109,20 +114,22 @@ function Base.convert(
)
end

function Base.convert(::typeof(Node), x::Number, operators::AbstractOperatorEnum; kws...)
return Node(; val=DEFAULT_NODE_TYPE(x))
function Base.convert(
::Type{N}, x::Number, operators::AbstractOperatorEnum; kws...
) where {N<:AbstractExpressionNode}
return constructorof(N)(; val=DEFAULT_NODE_TYPE(x))
end

function Base.convert(
::typeof(Node),
::Type{N},
expr::SymbolicUtils.Symbolic,
operators::AbstractOperatorEnum;
variable_names::Union{Array{String,1},Nothing}=nothing,
)
) where {N<:AbstractExpressionNode}
variable_names = deprecate_varmap(variable_names, nothing, :convert)
if !SymbolicUtils.istree(expr)
variable_names === nothing && return Node(String(expr.name))
return Node(String(expr.name), variable_names)
variable_names === nothing && return constructorof(N)(String(expr.name))
return constructorof(N)(String(expr.name), variable_names)
end

# First, we remove integer powers:
Expand All @@ -134,20 +141,21 @@ function Base.convert(
op = convert_to_function(SymbolicUtils.operation(expr), operators)
args = SymbolicUtils.arguments(expr)

length(args) > 2 && return split_eq(op, args, operators; variable_names=variable_names)
length(args) > 2 &&
return split_eq(op, args, operators, N; variable_names=variable_names)
ind = if length(args) == 2
findoperation(op, operators.binops)
else
findoperation(op, operators.unaops)
end

return Node(
ind, map(x -> convert(Node, x, operators; variable_names=variable_names), args)...
return constructorof(N)(
ind, map(x -> convert(N, x, operators; variable_names=variable_names), args)...
)
end

"""
node_to_symbolic(tree::Node, operators::AbstractOperatorEnum;
node_to_symbolic(tree::AbstractExpressionNode, operators::AbstractOperatorEnum;
variable_names::Union{Array{String, 1}, Nothing}=nothing,
index_functions::Bool=false)

Expand All @@ -156,17 +164,17 @@ will generate a symbolic equation in SymbolicUtils.jl format.

## Arguments

- `tree::Node`: The equation to convert.
- `tree::AbstractExpressionNode`: The equation to convert.
- `operators::AbstractOperatorEnum`: OperatorEnum, which contains the operators used in the equation.
- `variable_names::Union{Array{String, 1}, Nothing}=nothing`: What variable names to use for
each feature. Default is [x1, x2, x3, ...].
- `index_functions::Bool=false`: Whether to generate special names for the
operators, which then allows one to convert back to a `Node` format
operators, which then allows one to convert back to a `AbstractExpressionNode` format
using `symbolic_to_node`.
(CURRENTLY UNAVAILABLE - See https://github.com/MilesCranmer/SymbolicRegression.jl/pull/84).
"""
function node_to_symbolic(
tree::Node,
tree::AbstractExpressionNode,
operators::AbstractOperatorEnum;
variable_names::Union{Array{String,1},Nothing}=nothing,
index_functions::Bool=false,
Expand All @@ -192,13 +200,14 @@ end

function symbolic_to_node(
eqn::SymbolicUtils.Symbolic,
operators::AbstractOperatorEnum;
operators::AbstractOperatorEnum,
::Type{N}=Node;
variable_names::Union{Array{String,1},Nothing}=nothing,
# Deprecated:
varMap=nothing,
)::Node
) where {N<:AbstractExpressionNode}
variable_names = deprecate_varmap(variable_names, varMap, :symbolic_to_node)
return convert(Node, eqn, operators; variable_names=variable_names)
return convert(N, eqn, operators; variable_names=variable_names)
end

function multiply_powers(eqn::Number)::Tuple{SYMBOLIC_UTILS_TYPES,Bool}
Expand Down
5 changes: 4 additions & 1 deletion src/DynamicExpressions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,16 @@ import PackageExtensionCompat: @require_extensions
import Reexport: @reexport
@reexport import .EquationModule:
AbstractNode,
AbstractExpressionNode,
GraphNode,
Node,
string_tree,
print_tree,
copy_node,
set_node!,
tree_mapreduce,
filter_map
import .EquationModule: constructorof, preserve_sharing
@reexport import .EquationUtilsModule:
count_nodes,
count_constants,
Expand All @@ -38,7 +41,7 @@ import Reexport: @reexport
@reexport import .EvaluateEquationModule: eval_tree_array, differentiable_eval_tree_array
@reexport import .EvaluateEquationDerivativeModule:
eval_diff_tree_array, eval_grad_tree_array
@reexport import .SimplifyEquationModule: combine_operators, simplify_tree
@reexport import .SimplifyEquationModule: combine_operators, simplify_tree!
@reexport import .EvaluationHelpersModule
@reexport import .ExtensionInterfaceModule: node_to_symbolic, symbolic_to_node

Expand Down
Loading
Loading