From f285780bf1e63e07c5d4ff1250e8d66234510957 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Fri, 6 Sep 2024 11:11:12 -0400 Subject: [PATCH 01/10] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index ddec70195..71f9b6375 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Symbolics" uuid = "0c5d862f-8b57-4792-8d23-62f2024744c7" authors = ["Shashi Gowda "] -version = "6.6.0" +version = "6.6.1" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" From a5a56cc5c4d885bd9c7e4b0c206b611cf21d66ff Mon Sep 17 00:00:00 2001 From: n0rbed Date: Fri, 6 Sep 2024 12:41:11 +0300 Subject: [PATCH 02/10] simplified acos(0/1), asin(0/1) and anything + 0 --- src/solver/postprocess.jl | 41 +++++++++++++++++++++++++++++++-------- 1 file changed, 33 insertions(+), 8 deletions(-) diff --git a/src/solver/postprocess.jl b/src/solver/postprocess.jl index ff72fdf3f..afb87c263 100644 --- a/src/solver/postprocess.jl +++ b/src/solver/postprocess.jl @@ -1,4 +1,3 @@ - # Alex: make sure `Num`s are not processed here as they'd break it. _postprocess_root(x) = x @@ -32,12 +31,12 @@ function _postprocess_root(x::SymbolicUtils.BasicSymbolic) !iscall(x) && return x x = Symbolics.term(operation(x), map(_postprocess_root, arguments(x))...) + oper = operation(x) # sqrt(0), cbrt(0) => 0 # sqrt(1), cbrt(1) => 1 - if iscall(x) && - (operation(x) === sqrt || operation(x) === cbrt || operation(x) === ssqrt || - operation(x) === scbrt) + if (oper === sqrt || oper === cbrt || oper === ssqrt || + oper === scbrt) arg = arguments(x)[1] if isequal(arg, 0) || isequal(arg, 1) return arg @@ -45,17 +44,17 @@ function _postprocess_root(x::SymbolicUtils.BasicSymbolic) end # (X)^0 => 1 - if iscall(x) && operation(x) === (^) && isequal(arguments(x)[2], 0) + if oper === (^) && isequal(arguments(x)[2], 0) return 1 end # (X)^1 => X - if iscall(x) && operation(x) === (^) && isequal(arguments(x)[2], 1) + if oper === (^) && isequal(arguments(x)[2], 1) return arguments(x)[1] end # sqrt((N / D)^2 * M) => N / D * sqrt(M) - if iscall(x) && (operation(x) === sqrt || operation(x) === ssqrt) + if (oper === sqrt || oper === ssqrt) function squarefree_decomp(x::Integer) square, squarefree = big(1), big(1) for (p, d) in collect(Primes.factor(abs(x))) @@ -90,7 +89,7 @@ function _postprocess_root(x::SymbolicUtils.BasicSymbolic) end # (sqrt(N))^M => N^div(M, 2)*sqrt(N)^(mod(M, 2)) - if iscall(x) && operation(x) === (^) + if oper === (^) arg1, arg2 = arguments(x) if iscall(arg1) && (operation(arg1) === sqrt || operation(arg1) === ssqrt) if arg2 isa Integer @@ -105,6 +104,32 @@ function _postprocess_root(x::SymbolicUtils.BasicSymbolic) end end + arg = arguments(x)[1] + if oper === acos + if arg === 0 + return Symbolics.term(/, pi, 2) + elseif arg === 1 + return 0 + end + elseif oper === asin + if arg === 0 + return 0 + elseif arg === 1 + return Symbolics.term(/, pi, 2) + end + end + + if oper === (+) + args = arguments(x) + for arg in args + if isequal(arg, 0) + after_removing = setdiff(args, arg) + isone(length(after_removing)) && return after_removing[1] + return Symbolics.term(+, after_removing) + end + end + end + return x end From d072dae9c7975bb4395a5c3d783abb202ed29084 Mon Sep 17 00:00:00 2001 From: n0rbed Date: Sat, 7 Sep 2024 02:29:16 +0300 Subject: [PATCH 03/10] 2pi*N -> pi*N --- src/solver/ia_main.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/solver/ia_main.jl b/src/solver/ia_main.jl index c1f998a4f..8f83ca02b 100644 --- a/src/solver/ia_main.jl +++ b/src/solver/ia_main.jl @@ -123,7 +123,7 @@ function isolate(lhs, var; warns=true, conditions=[]) new_var = (@variables $new_var)[1] rhs = map( sol -> term(rev_oper[oper], sol) + - term(*, Base.MathConstants.pi, 2 * new_var), + term(*, Base.MathConstants.pi, new_var), rhs) @info string(new_var) * " ϵ" * " Ζ" From 318034edbda5db98847cdfa7caebee64e24343d4 Mon Sep 17 00:00:00 2001 From: n0rbed Date: Sat, 7 Sep 2024 07:28:30 +0300 Subject: [PATCH 04/10] acos, asin simplification rules --- src/solver/attract.jl | 6 ++---- src/solver/postprocess.jl | 35 ++++++++++++++++++++++++++--------- 2 files changed, 28 insertions(+), 13 deletions(-) diff --git a/src/solver/attract.jl b/src/solver/attract.jl index 027f85a99..6d03778e8 100644 --- a/src/solver/attract.jl +++ b/src/solver/attract.jl @@ -197,10 +197,8 @@ function attract_trig(lhs, var) r_trig = [@acrule(sin(~x::(contains_var))^2 + cos(~x::(contains_var))^2=>one(~x)) @acrule(sin(~x::(contains_var))^2 + -1=>-1 * cos(~x)^2) @acrule(cos(~x::(contains_var))^2 + -1=>-1 * sin(~x)^2) - @acrule(cos(~x::(contains_var))^2 + -1 * sin(~x::(contains_var))^2=>cos(2 * - ~x)) - @acrule(sin(~x::(contains_var))^2 + -1 * cos(~x::(contains_var))^2=>-cos(2 * - ~x)) + @acrule(cos(~x::(contains_var))^2 + -1 * sin(~x::(contains_var))^2=>cos(2*~x)) + @acrule(sin(~x::(contains_var))^2 + -1 * cos(~x::(contains_var))^2=>-cos(2*~x)) @acrule(cos(~x::(contains_var)) * sin(~x::(contains_var))=>sin(2 * ~x) / 2) @acrule(tan(~x::(contains_var))^2 + -1 * sec(~x::(contains_var))^2=>one(~x)) @acrule(-1 * tan(~x::(contains_var))^2 + sec(~x::(contains_var))^2=>one(~x)) diff --git a/src/solver/postprocess.jl b/src/solver/postprocess.jl index afb87c263..4c2f5249d 100644 --- a/src/solver/postprocess.jl +++ b/src/solver/postprocess.jl @@ -104,21 +104,38 @@ function _postprocess_root(x::SymbolicUtils.BasicSymbolic) end end - arg = arguments(x)[1] + isnegone(x) = isequal(-1, expand(x)) + ishalf(x) = isequal(1/2, expand(x)) + isneghalf(x) = isequal(-1/2, expand(x)) + symiszero(x) = isequal(0, expand(x)) + symisone(x) = isequal(1, expand(x)) + acos_rules = [(@rule acos(~x::symiszero) => Symbolics.term(/, pi, 2)), + (@rule acos(~x::symisone) => 0), + (@rule acos(~x::isnegone) => Symbolics.term(*, pi)), + (@rule acos(~x::ishalf) => Symbolics.term(/, pi, 3)), + (@rule acos(~x::isneghalf) => Symbolics.term(/, Symbolics.term(*,2,pi), 3)) + ] + + asin_rules = [(@rule asin(~x::symiszero) => 0), + (@rule asin(~x::symisone) => Symbolics.term(/, pi, 2)), + (@rule asin(~x::isnegone) => -Symbolics.term(/, pi, 2)), + (@rule asin(~x::ishalf) => Symbolics.term(/, pi, 6)), + (@rule asin(~x::isneghalf) => Symbolics.term(/, Symbolics.term(*,-1,pi), 6)) + ] + if oper === acos - if arg === 0 - return Symbolics.term(/, pi, 2) - elseif arg === 1 - return 0 + for r in acos_rules + after_r = r(x) + !isnothing(after_r) && return after_r end elseif oper === asin - if arg === 0 - return 0 - elseif arg === 1 - return Symbolics.term(/, pi, 2) + for r in asin_rules + after_r = r(x) + !isnothing(after_r) && return after_r end end + if oper === (+) args = arguments(x) for arg in args From 14e473e37554f8996446e6c61841c26efe6199ec Mon Sep 17 00:00:00 2001 From: n0rbed Date: Sat, 7 Sep 2024 08:07:42 +0300 Subject: [PATCH 05/10] added complex type --- src/solver/solve_helpers.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/solver/solve_helpers.jl b/src/solver/solve_helpers.jl index f2420969d..7496f65f6 100644 --- a/src/solver/solve_helpers.jl +++ b/src/solver/solve_helpers.jl @@ -78,7 +78,7 @@ function check_expr_validity(expr) valid_type = false if type_expr <: Number || type_expr == Num || type_expr == SymbolicUtils.BasicSymbolic{Real} || - type_expr == Complex{Num} || type_expr == ComplexTerm{Real} + type_expr == Complex{Num} || type_expr == ComplexTerm{Real} || type_expr == SymbolicUtils.BasicSymbolic{Complex{Real}} valid_type = true end iscall(unwrap(expr)) && @assert !hasderiv(unwrap(expr)) "Differential equations are not currently supported" From e4474cc85b022e758a0296c3b680c692dc7e0287 Mon Sep 17 00:00:00 2001 From: n0rbed Date: Sun, 8 Sep 2024 03:44:29 +0300 Subject: [PATCH 06/10] alternative solution --- ext/SymbolicsGroebnerExt.jl | 4 ---- ext/SymbolicsNemoExt.jl | 6 ++++++ src/solver/postprocess.jl | 43 ++++++++++++++----------------------- 3 files changed, 22 insertions(+), 31 deletions(-) diff --git a/ext/SymbolicsGroebnerExt.jl b/ext/SymbolicsGroebnerExt.jl index ebf5174a5..66d069060 100644 --- a/ext/SymbolicsGroebnerExt.jl +++ b/ext/SymbolicsGroebnerExt.jl @@ -320,13 +320,9 @@ end # Helps with precompilation time PrecompileTools.@setup_workload begin @variables a b c x y z - equation1 = a*log(x)^b + c ~ 0 - equation_actually_polynomial = sin(x^2 +1)^2 + sin(x^2 + 1) + 3 simple_linear_equations = [x - y, y + 2z] equations_intersect_sphere_line = [x^2 + y^2 + z^2 - 9, x - 2y + 3, y - z] PrecompileTools.@compile_workload begin - symbolic_solve(equation1, x) - symbolic_solve(equation_actually_polynomial) symbolic_solve(simple_linear_equations, [x, y], warns=false) symbolic_solve(equations_intersect_sphere_line, [x, y, z], warns=false) end diff --git a/ext/SymbolicsNemoExt.jl b/ext/SymbolicsNemoExt.jl index 16fe2414e..9c9f31db1 100644 --- a/ext/SymbolicsNemoExt.jl +++ b/ext/SymbolicsNemoExt.jl @@ -61,7 +61,13 @@ end PrecompileTools.@setup_workload begin @variables a b c x y z expr_with_params = expand((x + b)*(x^2 + 2x + 1)*(x^2 - a)) + equation1 = a*log(x)^b + c ~ 0 + equation_polynomial = 9^x + 3^x + 2 + exp_eq = 5*2^(x+1) + 7^(x+3) PrecompileTools.@compile_workload begin + symbolic_solve(equation1, x) + symbolic_solve(equation_polynomial, x) + symbolic_solve(exp_eq) symbolic_solve(expr_with_params, x, dropmultiplicity=false) symbolic_solve(x^10 - a^10, x, dropmultiplicity=false) end diff --git a/src/solver/postprocess.jl b/src/solver/postprocess.jl index 4c2f5249d..2a0ce1c85 100644 --- a/src/solver/postprocess.jl +++ b/src/solver/postprocess.jl @@ -104,38 +104,27 @@ function _postprocess_root(x::SymbolicUtils.BasicSymbolic) end end - isnegone(x) = isequal(-1, expand(x)) - ishalf(x) = isequal(1/2, expand(x)) - isneghalf(x) = isequal(-1/2, expand(x)) - symiszero(x) = isequal(0, expand(x)) - symisone(x) = isequal(1, expand(x)) - acos_rules = [(@rule acos(~x::symiszero) => Symbolics.term(/, pi, 2)), - (@rule acos(~x::symisone) => 0), - (@rule acos(~x::isnegone) => Symbolics.term(*, pi)), - (@rule acos(~x::ishalf) => Symbolics.term(/, pi, 3)), - (@rule acos(~x::isneghalf) => Symbolics.term(/, Symbolics.term(*,2,pi), 3)) + opers = [acos, asin, atan] + exacts = [0, Symbolics.term(*, pi), Symbolics.term(/,pi,3), + Symbolics.term(/, pi, 2), + Symbolics.term(/, Symbolics.term(*, 2, pi), 3), + Symbolics.term(/, pi, 6), + Symbolics.term(/, Symbolics.term(*, 5, pi), 6), + Symbolics.term(/, pi, 4) ] - asin_rules = [(@rule asin(~x::symiszero) => 0), - (@rule asin(~x::symisone) => Symbolics.term(/, pi, 2)), - (@rule asin(~x::isnegone) => -Symbolics.term(/, pi, 2)), - (@rule asin(~x::ishalf) => Symbolics.term(/, pi, 6)), - (@rule asin(~x::isneghalf) => Symbolics.term(/, Symbolics.term(*,-1,pi), 6)) - ] - - if oper === acos - for r in acos_rules - after_r = r(x) - !isnothing(after_r) && return after_r - end - elseif oper === asin - for r in asin_rules - after_r = r(x) - !isnothing(after_r) && return after_r + if any(isequal(oper, o) for o in opers) && isempty(Symbolics.get_variables(x)) + val = eval(Symbolics.toexpr(x)) + for i in eachindex(exacts) + exact_val = eval(Symbolics.toexpr(exacts[i])) + if isapprox(exact_val, val, atol=1e-6) + return exacts[i] + elseif isapprox(-exact_val, val, atol=1e-6) + return -exacts[i] + end end end - if oper === (+) args = arguments(x) for arg in args From b9cd1f51461790d037a658277f325d73a96d076e Mon Sep 17 00:00:00 2001 From: n0rbed Date: Sun, 8 Sep 2024 09:11:37 +0300 Subject: [PATCH 07/10] function --- src/solver/postprocess.jl | 52 ++++++++++++++++++++++++--------------- 1 file changed, 32 insertions(+), 20 deletions(-) diff --git a/src/solver/postprocess.jl b/src/solver/postprocess.jl index 2a0ce1c85..6236573f9 100644 --- a/src/solver/postprocess.jl +++ b/src/solver/postprocess.jl @@ -104,26 +104,8 @@ function _postprocess_root(x::SymbolicUtils.BasicSymbolic) end end - opers = [acos, asin, atan] - exacts = [0, Symbolics.term(*, pi), Symbolics.term(/,pi,3), - Symbolics.term(/, pi, 2), - Symbolics.term(/, Symbolics.term(*, 2, pi), 3), - Symbolics.term(/, pi, 6), - Symbolics.term(/, Symbolics.term(*, 5, pi), 6), - Symbolics.term(/, pi, 4) - ] - - if any(isequal(oper, o) for o in opers) && isempty(Symbolics.get_variables(x)) - val = eval(Symbolics.toexpr(x)) - for i in eachindex(exacts) - exact_val = eval(Symbolics.toexpr(exacts[i])) - if isapprox(exact_val, val, atol=1e-6) - return exacts[i] - elseif isapprox(-exact_val, val, atol=1e-6) - return -exacts[i] - end - end - end + trig_simplified = check_trig_consts(x) + !isequal(trig_simplified, x) && return trig_simplified if oper === (+) args = arguments(x) @@ -153,3 +135,33 @@ function postprocess_root(x) end x # unreachable end + +function check_trig_consts(x) + !iscall(x) && return x + + oper = operation(x) + inv_opers = [asin, acos, atan] + inv_exacts = [0, Symbolics.term(*, pi), + Symbolics.term(/,pi,3), + Symbolics.term(/, pi, 2), + Symbolics.term(/, Symbolics.term(*, 2, pi), 3), + Symbolics.term(/, pi, 6), + Symbolics.term(/, Symbolics.term(*, 5, pi), 6), + Symbolics.term(/, pi, 4) + ] + + if any(isequal(oper, o) for o in inv_opers) && isempty(Symbolics.get_variables(x)) + val = eval(Symbolics.toexpr(x)) + for i in eachindex(inv_exacts) + exact_val = eval(Symbolics.toexpr(inv_exacts[i])) + if isapprox(exact_val, val, atol=1e-6) + return inv_exacts[i] + elseif isapprox(-exact_val, val, atol=1e-6) + return -inv_exacts[i] + end + end + end + + # add [sin, cos, tan] simplifications in the future? + return x +end From d352d8a5c82a2cb7dccace1cb517328c4c942cbd Mon Sep 17 00:00:00 2001 From: n0rbed Date: Mon, 9 Sep 2024 03:26:16 +0300 Subject: [PATCH 08/10] eval -> symbolic_to_float --- src/solver/postprocess.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/solver/postprocess.jl b/src/solver/postprocess.jl index 6236573f9..741c36dc4 100644 --- a/src/solver/postprocess.jl +++ b/src/solver/postprocess.jl @@ -151,9 +151,9 @@ function check_trig_consts(x) ] if any(isequal(oper, o) for o in inv_opers) && isempty(Symbolics.get_variables(x)) - val = eval(Symbolics.toexpr(x)) + val = Symbolics.symbolic_to_float(x) for i in eachindex(inv_exacts) - exact_val = eval(Symbolics.toexpr(inv_exacts[i])) + exact_val = Symbolics.symbolic_to_float(inv_exacts[i]) if isapprox(exact_val, val, atol=1e-6) return inv_exacts[i] elseif isapprox(-exact_val, val, atol=1e-6) From 4a5b616d543aa3d24824196bfdf8e6581131c788 Mon Sep 17 00:00:00 2001 From: n0rbed Date: Fri, 13 Sep 2024 13:04:07 +0300 Subject: [PATCH 09/10] docs and const arr --- src/solver/postprocess.jl | 48 +++++++++++++++++++++++++++------------ 1 file changed, 34 insertions(+), 14 deletions(-) diff --git a/src/solver/postprocess.jl b/src/solver/postprocess.jl index 741c36dc4..4dce04fc6 100644 --- a/src/solver/postprocess.jl +++ b/src/solver/postprocess.jl @@ -104,8 +104,7 @@ function _postprocess_root(x::SymbolicUtils.BasicSymbolic) end end - trig_simplified = check_trig_consts(x) - !isequal(trig_simplified, x) && return trig_simplified + x = convert_consts(x) if oper === (+) args = arguments(x) @@ -136,28 +135,49 @@ function postprocess_root(x) x # unreachable end -function check_trig_consts(x) - !iscall(x) && return x - oper = operation(x) - inv_opers = [asin, acos, atan] - inv_exacts = [0, Symbolics.term(*, pi), +inv_exacts = [0, Symbolics.term(*, pi), Symbolics.term(/,pi,3), Symbolics.term(/, pi, 2), Symbolics.term(/, Symbolics.term(*, 2, pi), 3), Symbolics.term(/, pi, 6), Symbolics.term(/, Symbolics.term(*, 5, pi), 6), Symbolics.term(/, pi, 4) - ] +] +inv_evald = Symbolics.symbolic_to_float.(inv_exacts) + +const inv_pairs = collect(zip(inv_exacts, inv_evald)) +""" + function convert_consts(x) +This function takes BasicSymbolic terms as input (x) and attempts +to simplify these basic symbolic terms using known values. +Currently, this function only supports inverse trignometric functions. + +## Examples +```jldoctest +julia> Symbolics.convert_consts(Symbolics.term(acos, 0)) +π / 2 + +julia> Symbolics.convert_consts(Symbolics.term(atan, 0)) +0 + +julia> Symbolics.convert_consts(Symbolics.term(atan, 1)) +π / 4 +``` +""" +function convert_consts(x) + !iscall(x) && return x + + oper = operation(x) + inv_opers = [asin, acos, atan] if any(isequal(oper, o) for o in inv_opers) && isempty(Symbolics.get_variables(x)) val = Symbolics.symbolic_to_float(x) - for i in eachindex(inv_exacts) - exact_val = Symbolics.symbolic_to_float(inv_exacts[i]) - if isapprox(exact_val, val, atol=1e-6) - return inv_exacts[i] - elseif isapprox(-exact_val, val, atol=1e-6) - return -inv_exacts[i] + for (exact, evald) in inv_pairs + if isapprox(evald, val) + return exact + elseif isapprox(-evald, val) + return -exact end end end From eb664afd8207eba6952a270744e9716538aaab9e Mon Sep 17 00:00:00 2001 From: n0rbed Date: Fri, 13 Sep 2024 13:06:15 +0300 Subject: [PATCH 10/10] spelling mistake --- src/solver/postprocess.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/solver/postprocess.jl b/src/solver/postprocess.jl index 4dce04fc6..4764690aa 100644 --- a/src/solver/postprocess.jl +++ b/src/solver/postprocess.jl @@ -151,7 +151,7 @@ const inv_pairs = collect(zip(inv_exacts, inv_evald)) function convert_consts(x) This function takes BasicSymbolic terms as input (x) and attempts to simplify these basic symbolic terms using known values. -Currently, this function only supports inverse trignometric functions. +Currently, this function only supports inverse trigonometric functions. ## Examples ```jldoctest