diff --git a/Project.toml b/Project.toml index eac9803c..a15ec484 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "DynamicExpressions" uuid = "a40a106e-89c9-4ca8-8020-a735e8728b6b" authors = ["MilesCranmer "] -version = "0.18.0" +version = "0.18.1" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/ext/DynamicExpressionsSymbolicUtilsExt.jl b/ext/DynamicExpressionsSymbolicUtilsExt.jl index 28c3b12a..5f2dd267 100644 --- a/ext/DynamicExpressionsSymbolicUtilsExt.jl +++ b/ext/DynamicExpressionsSymbolicUtilsExt.jl @@ -111,7 +111,7 @@ end function Base.convert( ::typeof(SymbolicUtils.Symbolic), tree::Union{AbstractExpression,AbstractExpressionNode}, - operators::AbstractOperatorEnum; + operators::Union{AbstractOperatorEnum,Nothing}=nothing; variable_names::Union{Array{String,1},Nothing}=nothing, index_functions::Bool=false, # Deprecated: @@ -119,7 +119,10 @@ function Base.convert( ) variable_names = deprecate_varmap(variable_names, varMap, :convert) return node_to_symbolic( - tree, operators; variable_names=variable_names, index_functions=index_functions + tree, + get_operators(tree, operators); + variable_names=variable_names, + index_functions=index_functions, ) end diff --git a/src/Expression.jl b/src/Expression.jl index 986b85de..9d1e0e3f 100644 --- a/src/Expression.jl +++ b/src/Expression.jl @@ -101,7 +101,9 @@ or `cur_operators` if it is not `nothing`. If left as default, it requires `cur_operators` to not be `nothing`. `cur_operators` would typically be an `OperatorEnum`. """ -function get_operators(ex::AbstractExpression, operators) +function get_operators( + ex::AbstractExpression, operators::Union{AbstractOperatorEnum,Nothing}=nothing +) return error("`get_operators` function must be implemented for $(typeof(ex)) types.") end @@ -110,7 +112,10 @@ end The same as `operators`, but for variable names. """ -function get_variable_names(ex::AbstractExpression, variable_names) +function get_variable_names( + ex::AbstractExpression, + variable_names::Union{Nothing,AbstractVector{<:AbstractString}}=nothing, +) return error( "`get_variable_names` function must be implemented for $(typeof(ex)) types." ) @@ -179,10 +184,23 @@ function preserve_sharing(::Union{E,Type{E}}) where {T,N,E<:AbstractExpression{T return preserve_sharing(N) end -function get_operators(ex::Expression, operators=nothing) +function get_operators( + tree::AbstractExpressionNode, operators::Union{AbstractOperatorEnum,Nothing}=nothing +) + if operators === nothing + throw(ArgumentError("`operators` must be provided for $(typeof(tree)) types.")) + else + return operators + end +end +function get_operators( + ex::Expression, operators::Union{AbstractOperatorEnum,Nothing}=nothing +) return operators === nothing ? ex.metadata.operators : operators end -function get_variable_names(ex::Expression, variable_names=nothing) +function get_variable_names( + ex::Expression, variable_names::Union{Nothing,AbstractVector{<:AbstractString}}=nothing +) return variable_names === nothing ? ex.metadata.variable_names : variable_names end function get_tree(ex::Expression) @@ -249,7 +267,10 @@ end import ..StringsModule: string_tree, print_tree function string_tree( - ex::AbstractExpression, operators=nothing; variable_names=nothing, kws... + ex::AbstractExpression, + operators::Union{AbstractOperatorEnum,Nothing}=nothing; + variable_names=nothing, + kws..., ) return string_tree( get_tree(ex), @@ -260,7 +281,11 @@ function string_tree( end for io in ((), (:(io::IO),)) @eval function print_tree( - $(io...), ex::AbstractExpression, operators=nothing; variable_names=nothing, kws... + $(io...), + ex::AbstractExpression, + operators::Union{AbstractOperatorEnum,Nothing}=nothing; + variable_names=nothing, + kws..., ) return println($(io...), string_tree(ex, operators; variable_names, kws...)) end @@ -283,7 +308,9 @@ function max_feature(ex::AbstractExpression) ) end -function _validate_input(ex::AbstractExpression, X, operators) +function _validate_input( + ex::AbstractExpression, X, operators::Union{AbstractOperatorEnum,Nothing} +) if get_operators(ex, operators) isa OperatorEnum @assert X isa AbstractMatrix @assert max_feature(ex) <= size(X, 1) @@ -292,7 +319,10 @@ function _validate_input(ex::AbstractExpression, X, operators) end function eval_tree_array( - ex::AbstractExpression, cX::AbstractMatrix, operators=nothing; kws... + ex::AbstractExpression, + cX::AbstractMatrix, + operators::Union{AbstractOperatorEnum,Nothing}=nothing; + kws..., ) _validate_input(ex, cX, operators) return eval_tree_array(get_tree(ex), cX, get_operators(ex, operators); kws...) @@ -305,7 +335,10 @@ import ..EvaluateDerivativeModule: eval_grad_tree_array # - differentiable_eval_tree_array function eval_grad_tree_array( - ex::AbstractExpression, cX::AbstractMatrix, operators=nothing; kws... + ex::AbstractExpression, + cX::AbstractMatrix, + operators::Union{AbstractOperatorEnum,Nothing}=nothing; + kws..., ) _validate_input(ex, cX, operators) return eval_grad_tree_array(get_tree(ex), cX, get_operators(ex, operators); kws...) @@ -319,14 +352,16 @@ end function _grad_evaluator( ex::AbstractExpression, cX::AbstractMatrix, - operators=nothing; + operators::Union{AbstractOperatorEnum,Nothing}=nothing; variable=Val(true), kws..., ) _validate_input(ex, cX, operators) return _grad_evaluator(get_tree(ex), cX, get_operators(ex, operators); variable, kws...) end -function (ex::AbstractExpression)(X, operators=nothing; kws...) +function (ex::AbstractExpression)( + X, operators::Union{AbstractOperatorEnum,Nothing}=nothing; kws... +) _validate_input(ex, X, operators) return get_tree(ex)(X, get_operators(ex, operators); kws...) end diff --git a/src/ParametricExpression.jl b/src/ParametricExpression.jl index d4ecac10..09ce8aa3 100644 --- a/src/ParametricExpression.jl +++ b/src/ParametricExpression.jl @@ -2,6 +2,7 @@ module ParametricExpressionModule using DispatchDoctor: @stable, @unstable +using ..OperatorEnumModule: AbstractOperatorEnum using ..NodeModule: AbstractExpressionNode, Node, tree_mapreduce using ..ExpressionModule: AbstractExpression, Metadata @@ -70,7 +71,7 @@ struct ParametricExpression{ end function ParametricExpression( tree::ParametricNode{T1}; - operators, + operators::Union{AbstractOperatorEnum,Nothing}, variable_names, parameters::AbstractMatrix{T2}, parameter_names, @@ -141,10 +142,15 @@ end get_contents(ex::ParametricExpression) = ex.tree get_metadata(ex::ParametricExpression) = ex.metadata get_tree(ex::ParametricExpression) = ex.tree -function get_operators(ex::ParametricExpression, operators=nothing) +function get_operators( + ex::ParametricExpression, operators::Union{AbstractOperatorEnum,Nothing}=nothing +) return operators === nothing ? ex.metadata.operators : operators end -function get_variable_names(ex::ParametricExpression, variable_names=nothing) +function get_variable_names( + ex::ParametricExpression, + variable_names::Union{Nothing,AbstractVector{<:AbstractString}}=nothing, +) return variable_names === nothing ? ex.metadata.variable_names : variable_names end @inline _copy_with_nothing(x) = copy(x) @@ -232,15 +238,18 @@ function Base.convert(::Type{Node}, ex::ParametricExpression{T}) where {T} ) end #! format: off -function (ex::ParametricExpression)(X::AbstractMatrix, operators=nothing; kws...) +function (ex::ParametricExpression)(X::AbstractMatrix, operators::Union{AbstractOperatorEnum,Nothing}=nothing; kws...) return eval_tree_array(ex, X, operators; kws...) # Will error end -function eval_tree_array(::ParametricExpression{T}, ::AbstractMatrix{T}, operators=nothing; kws...) where {T} +function eval_tree_array(::ParametricExpression{T}, ::AbstractMatrix{T}, operators::Union{AbstractOperatorEnum,Nothing}=nothing; kws...) where {T} return error("Incorrect call. You must pass the `classes::Vector` argument when calling `eval_tree_array`.") end #! format: on function (ex::ParametricExpression)( - X::AbstractMatrix{T}, classes::AbstractVector{<:Integer}, operators=nothing; kws... + X::AbstractMatrix{T}, + classes::AbstractVector{<:Integer}, + operators::Union{AbstractOperatorEnum,Nothing}=nothing; + kws..., ) where {T} (output, flag) = eval_tree_array(ex, X, classes, operators; kws...) # Will error if !flag @@ -252,7 +261,7 @@ function eval_tree_array( ex::ParametricExpression{T}, X::AbstractMatrix{T}, classes::AbstractVector{<:Integer}, - operators=nothing; + operators::Union{AbstractOperatorEnum,Nothing}=nothing; kws..., ) where {T} @assert length(classes) == size(X, 2) @@ -270,7 +279,7 @@ function eval_tree_array( end function string_tree( ex::ParametricExpression, - operators=nothing; + operators::Union{AbstractOperatorEnum,Nothing}=nothing; variable_names=nothing, display_variable_names=nothing, X_sym_units=nothing, diff --git a/test/test_expressions.jl b/test/test_expressions.jl index 75fede26..1fb9470d 100644 --- a/test/test_expressions.jl +++ b/test/test_expressions.jl @@ -226,9 +226,13 @@ end @testitem "Miscellaneous expression calls" begin using DynamicExpressions + using DynamicExpressions: get_tree, get_operators ex = @parse_expression(x1 + 1.5, binary_operators = [+], variable_names = ["x1"]) @test DynamicExpressions.ExpressionModule.node_type(ex) <: Node @test !isempty(ex) + + tree = get_tree(ex) + @test_throws ArgumentError get_operators(tree, nothing) end diff --git a/test/test_multi_expression.jl b/test/test_multi_expression.jl index 56ee6d0e..e049aebd 100644 --- a/test/test_multi_expression.jl +++ b/test/test_multi_expression.jl @@ -87,10 +87,15 @@ )::Expression{T,N} return fused_expression.tree end - function DE.get_operators(ex::MultiScalarExpression, operators=nothing) + function DE.get_operators( + ex::MultiScalarExpression, operators::Union{AbstractOperatorEnum,Nothing}=nothing + ) return operators === nothing ? ex.metadata.operators : operators end - function DE.get_variable_names(ex::MultiScalarExpression, variable_names=nothing) + function DE.get_variable_names( + ex::MultiScalarExpression, + variable_names::Union{Nothing,AbstractVector{<:AbstractString}}=nothing, + ) return variable_names === nothing ? ex.metadata.variable_names : variable_names end function Base.copy(ex::MultiScalarExpression)