diff --git a/ext/SymbolicsGroebnerExt.jl b/ext/SymbolicsGroebnerExt.jl index ebf5174a5..66d069060 100644 --- a/ext/SymbolicsGroebnerExt.jl +++ b/ext/SymbolicsGroebnerExt.jl @@ -320,13 +320,9 @@ end # Helps with precompilation time PrecompileTools.@setup_workload begin @variables a b c x y z - equation1 = a*log(x)^b + c ~ 0 - equation_actually_polynomial = sin(x^2 +1)^2 + sin(x^2 + 1) + 3 simple_linear_equations = [x - y, y + 2z] equations_intersect_sphere_line = [x^2 + y^2 + z^2 - 9, x - 2y + 3, y - z] PrecompileTools.@compile_workload begin - symbolic_solve(equation1, x) - symbolic_solve(equation_actually_polynomial) symbolic_solve(simple_linear_equations, [x, y], warns=false) symbolic_solve(equations_intersect_sphere_line, [x, y, z], warns=false) end diff --git a/ext/SymbolicsNemoExt.jl b/ext/SymbolicsNemoExt.jl index 16fe2414e..9c9f31db1 100644 --- a/ext/SymbolicsNemoExt.jl +++ b/ext/SymbolicsNemoExt.jl @@ -61,7 +61,13 @@ end PrecompileTools.@setup_workload begin @variables a b c x y z expr_with_params = expand((x + b)*(x^2 + 2x + 1)*(x^2 - a)) + equation1 = a*log(x)^b + c ~ 0 + equation_polynomial = 9^x + 3^x + 2 + exp_eq = 5*2^(x+1) + 7^(x+3) PrecompileTools.@compile_workload begin + symbolic_solve(equation1, x) + symbolic_solve(equation_polynomial, x) + symbolic_solve(exp_eq) symbolic_solve(expr_with_params, x, dropmultiplicity=false) symbolic_solve(x^10 - a^10, x, dropmultiplicity=false) end diff --git a/src/solver/attract.jl b/src/solver/attract.jl index 027f85a99..6d03778e8 100644 --- a/src/solver/attract.jl +++ b/src/solver/attract.jl @@ -197,10 +197,8 @@ function attract_trig(lhs, var) r_trig = [@acrule(sin(~x::(contains_var))^2 + cos(~x::(contains_var))^2=>one(~x)) @acrule(sin(~x::(contains_var))^2 + -1=>-1 * cos(~x)^2) @acrule(cos(~x::(contains_var))^2 + -1=>-1 * sin(~x)^2) - @acrule(cos(~x::(contains_var))^2 + -1 * sin(~x::(contains_var))^2=>cos(2 * - ~x)) - @acrule(sin(~x::(contains_var))^2 + -1 * cos(~x::(contains_var))^2=>-cos(2 * - ~x)) + @acrule(cos(~x::(contains_var))^2 + -1 * sin(~x::(contains_var))^2=>cos(2*~x)) + @acrule(sin(~x::(contains_var))^2 + -1 * cos(~x::(contains_var))^2=>-cos(2*~x)) @acrule(cos(~x::(contains_var)) * sin(~x::(contains_var))=>sin(2 * ~x) / 2) @acrule(tan(~x::(contains_var))^2 + -1 * sec(~x::(contains_var))^2=>one(~x)) @acrule(-1 * tan(~x::(contains_var))^2 + sec(~x::(contains_var))^2=>one(~x)) diff --git a/src/solver/ia_main.jl b/src/solver/ia_main.jl index c1f998a4f..8f83ca02b 100644 --- a/src/solver/ia_main.jl +++ b/src/solver/ia_main.jl @@ -123,7 +123,7 @@ function isolate(lhs, var; warns=true, conditions=[]) new_var = (@variables $new_var)[1] rhs = map( sol -> term(rev_oper[oper], sol) + - term(*, Base.MathConstants.pi, 2 * new_var), + term(*, Base.MathConstants.pi, new_var), rhs) @info string(new_var) * " ϵ" * " Ζ" diff --git a/src/solver/postprocess.jl b/src/solver/postprocess.jl index ff72fdf3f..4764690aa 100644 --- a/src/solver/postprocess.jl +++ b/src/solver/postprocess.jl @@ -1,4 +1,3 @@ - # Alex: make sure `Num`s are not processed here as they'd break it. _postprocess_root(x) = x @@ -32,12 +31,12 @@ function _postprocess_root(x::SymbolicUtils.BasicSymbolic) !iscall(x) && return x x = Symbolics.term(operation(x), map(_postprocess_root, arguments(x))...) + oper = operation(x) # sqrt(0), cbrt(0) => 0 # sqrt(1), cbrt(1) => 1 - if iscall(x) && - (operation(x) === sqrt || operation(x) === cbrt || operation(x) === ssqrt || - operation(x) === scbrt) + if (oper === sqrt || oper === cbrt || oper === ssqrt || + oper === scbrt) arg = arguments(x)[1] if isequal(arg, 0) || isequal(arg, 1) return arg @@ -45,17 +44,17 @@ function _postprocess_root(x::SymbolicUtils.BasicSymbolic) end # (X)^0 => 1 - if iscall(x) && operation(x) === (^) && isequal(arguments(x)[2], 0) + if oper === (^) && isequal(arguments(x)[2], 0) return 1 end # (X)^1 => X - if iscall(x) && operation(x) === (^) && isequal(arguments(x)[2], 1) + if oper === (^) && isequal(arguments(x)[2], 1) return arguments(x)[1] end # sqrt((N / D)^2 * M) => N / D * sqrt(M) - if iscall(x) && (operation(x) === sqrt || operation(x) === ssqrt) + if (oper === sqrt || oper === ssqrt) function squarefree_decomp(x::Integer) square, squarefree = big(1), big(1) for (p, d) in collect(Primes.factor(abs(x))) @@ -90,7 +89,7 @@ function _postprocess_root(x::SymbolicUtils.BasicSymbolic) end # (sqrt(N))^M => N^div(M, 2)*sqrt(N)^(mod(M, 2)) - if iscall(x) && operation(x) === (^) + if oper === (^) arg1, arg2 = arguments(x) if iscall(arg1) && (operation(arg1) === sqrt || operation(arg1) === ssqrt) if arg2 isa Integer @@ -105,6 +104,19 @@ function _postprocess_root(x::SymbolicUtils.BasicSymbolic) end end + x = convert_consts(x) + + if oper === (+) + args = arguments(x) + for arg in args + if isequal(arg, 0) + after_removing = setdiff(args, arg) + isone(length(after_removing)) && return after_removing[1] + return Symbolics.term(+, after_removing) + end + end + end + return x end @@ -122,3 +134,54 @@ function postprocess_root(x) end x # unreachable end + + +inv_exacts = [0, Symbolics.term(*, pi), + Symbolics.term(/,pi,3), + Symbolics.term(/, pi, 2), + Symbolics.term(/, Symbolics.term(*, 2, pi), 3), + Symbolics.term(/, pi, 6), + Symbolics.term(/, Symbolics.term(*, 5, pi), 6), + Symbolics.term(/, pi, 4) +] +inv_evald = Symbolics.symbolic_to_float.(inv_exacts) + +const inv_pairs = collect(zip(inv_exacts, inv_evald)) +""" + function convert_consts(x) +This function takes BasicSymbolic terms as input (x) and attempts +to simplify these basic symbolic terms using known values. +Currently, this function only supports inverse trigonometric functions. + +## Examples +```jldoctest +julia> Symbolics.convert_consts(Symbolics.term(acos, 0)) +π / 2 + +julia> Symbolics.convert_consts(Symbolics.term(atan, 0)) +0 + +julia> Symbolics.convert_consts(Symbolics.term(atan, 1)) +π / 4 +``` +""" +function convert_consts(x) + !iscall(x) && return x + + oper = operation(x) + inv_opers = [asin, acos, atan] + + if any(isequal(oper, o) for o in inv_opers) && isempty(Symbolics.get_variables(x)) + val = Symbolics.symbolic_to_float(x) + for (exact, evald) in inv_pairs + if isapprox(evald, val) + return exact + elseif isapprox(-evald, val) + return -exact + end + end + end + + # add [sin, cos, tan] simplifications in the future? + return x +end diff --git a/src/solver/solve_helpers.jl b/src/solver/solve_helpers.jl index f2420969d..7496f65f6 100644 --- a/src/solver/solve_helpers.jl +++ b/src/solver/solve_helpers.jl @@ -78,7 +78,7 @@ function check_expr_validity(expr) valid_type = false if type_expr <: Number || type_expr == Num || type_expr == SymbolicUtils.BasicSymbolic{Real} || - type_expr == Complex{Num} || type_expr == ComplexTerm{Real} + type_expr == Complex{Num} || type_expr == ComplexTerm{Real} || type_expr == SymbolicUtils.BasicSymbolic{Complex{Real}} valid_type = true end iscall(unwrap(expr)) && @assert !hasderiv(unwrap(expr)) "Differential equations are not currently supported"