Skip to content

Commit

Permalink
Fix helper function generators
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed Oct 22, 2022
1 parent 061c198 commit ebe61f9
Showing 1 changed file with 24 additions and 15 deletions.
39 changes: 24 additions & 15 deletions src/OperatorEnumConstruction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@ module OperatorEnumConstructionModule

import Zygote: gradient
import ..UtilsModule: max_ops
import ..OperatorEnumModule: OperatorEnum, GenericOperatorEnum
import ..OperatorEnumModule: AbstractOperatorEnum, OperatorEnum, GenericOperatorEnum
import ..EquationModule: string_tree, Node
import ..EvaluateEquationModule: eval_tree_array
import ..EvaluateEquationDerivativeModule: eval_grad_tree_array

function create_evaluation_helper_functions(operators::OperatorEnum)
function create_evaluation_helpers!(operators::OperatorEnum)
@eval begin
Base.print(io::IO, tree::Node) = print(io, string_tree(tree, $operators))
Base.show(io::IO, tree::Node) = print(io, string_tree(tree, $operators))
Expand Down Expand Up @@ -37,7 +37,7 @@ function create_evaluation_helper_functions(operators::OperatorEnum)
end
end

function create_evaluation_helper_functions(operators::GenericOperatorEnum)
function create_evaluation_helpers!(operators::GenericOperatorEnum)
@eval begin
Base.print(io::IO, tree::Node) = print(io, string_tree(tree, $operators))
Base.show(io::IO, tree::Node) = print(io, string_tree(tree, $operators))
Expand All @@ -54,11 +54,14 @@ function create_evaluation_helper_functions(operators::GenericOperatorEnum)
end
end

function create_node_helper_functions(
function create_construction_helpers!(
operators::AbstractOperatorEnum; extend_user_operators::Bool=false
)
for (op, f) in enumerate(map(Symbol, binary_operators))
if typeof(operators) <: OperatorEnum
is_scalar_operator_enum = typeof(operators) <: OperatorEnum
type_requirements = is_scalar_operator_enum ? Real : Any

for (op, f) in enumerate(map(Symbol, operators.binops))
if is_scalar_operator_enum
f = if f in [:pow, :safe_pow]
Symbol(^)
else
Expand All @@ -74,7 +77,9 @@ function create_node_helper_functions(
Base.MainInclude.eval(
quote
import DynamicExpressions: Node
function $f(l::Node{T1}, r::Node{T2}) where {T1<:Real,T2<:Real}
function $f(
l::Node{T1}, r::Node{T2}
) where {T1<:$type_requirements,T2<:$type_requirements}
T = promote_type(T1, T2)
l = convert(Node{T}, l)
r = convert(Node{T}, r)
Expand All @@ -84,7 +89,9 @@ function create_node_helper_functions(
return Node($op, l, r)
end
end
function $f(l::Node{T1}, r::T2) where {T1<:Real,T2<:Real}
function $f(
l::Node{T1}, r::T2
) where {T1<:$type_requirements,T2<:$type_requirements}
T = promote_type(T1, T2)
l = convert(Node{T}, l)
r = convert(T, r)
Expand All @@ -94,7 +101,9 @@ function create_node_helper_functions(
Node($op, l, Node(; val=r))
end
end
function $f(l::T1, r::Node{T2}) where {T1<:Real,T2<:Real}
function $f(
l::T1, r::Node{T2}
) where {T1<:$type_requirements,T2<:$type_requirements}
T = promote_type(T1, T2)
l = convert(T, l)
r = convert(Node{T}, r)
Expand All @@ -108,7 +117,7 @@ function create_node_helper_functions(
)
end
# Redefine Base operations:
for (op, f) in enumerate(map(Symbol, unary_operators))
for (op, f) in enumerate(map(Symbol, operators.unaops))
if isdefined(Base, f)
f = :(Base.$(f))
elseif !extend_user_operators
Expand All @@ -118,7 +127,7 @@ function create_node_helper_functions(
Base.MainInclude.eval(
quote
import DynamicExpressions: Node
function $f(l::Node{T})::Node{T} where {T<:Real}
function $f(l::Node{T})::Node{T} where {T<:$type_requirements}
return l.constant ? Node(; val=$f(l.val)) : Node($op, l)
end
end,
Expand Down Expand Up @@ -209,8 +218,8 @@ function OperatorEnum(;
)

if define_helper_functions
create_node_helper_functions(operators; extend_user_operators=extend_user_operators)
create_evaluation_helper_functions(operators)
create_construction_helpers!(operators; extend_user_operators=extend_user_operators)
create_evaluation_helpers!(operators)
end

return operators
Expand Down Expand Up @@ -249,8 +258,8 @@ function GenericOperatorEnum(;
operators = GenericOperatorEnum(binary_operators, unary_operators)

if define_helper_functions
create_node_helper_functions(operators; extend_user_operators=extend_user_operators)
create_evaluation_helper_functions(operators)
create_construction_helpers!(operators; extend_user_operators=extend_user_operators)
create_evaluation_helpers!(operators)
end

return operators
Expand Down

0 comments on commit ebe61f9

Please sign in to comment.