Skip to content

Commit

Permalink
Merge branch 'JuliaSymbolics:master' into trig_issue
Browse files Browse the repository at this point in the history
  • Loading branch information
n0rbed authored Sep 8, 2024
2 parents a17a71f + 81d894b commit 8dd6f4b
Show file tree
Hide file tree
Showing 9 changed files with 53 additions and 102 deletions.
8 changes: 4 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Symbolics"
uuid = "0c5d862f-8b57-4792-8d23-62f2024744c7"
authors = ["Shashi Gowda <[email protected]>"]
version = "6.7.0"
version = "6.11.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down Expand Up @@ -44,15 +44,15 @@ TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c"
[weakdeps]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Groebner = "0b43b601-686d-58a3-8a1c-6623616c7cd4"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
Nemo = "2edaba10-b0f1-5616-af89-8c11ac63239a"
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
SymPy = "24249f21-da20-56a4-8eb1-6a02cf4ae2e6"

[extensions]
SymbolicsForwardDiffExt = "ForwardDiff"
SymbolicsGroebnerExt = "Groebner"
SymbolicsLuxCoreExt = "LuxCore"
SymbolicsLuxExt = "Lux"
SymbolicsNemoExt = "Nemo"
SymbolicsPreallocationToolsExt = ["PreallocationTools", "ForwardDiff"]
SymbolicsSymPyExt = "SymPy"
Expand All @@ -76,7 +76,7 @@ LaTeXStrings = "1.3"
LambertW = "0.4.5"
Latexify = "0.16"
LogExpFunctions = "0.3"
LuxCore = "0.1.11"
Lux = "1"
MacroTools = "0.5"
NaNMath = "1"
Nemo = "0.45, 0.46"
Expand Down
11 changes: 0 additions & 11 deletions ext/SymbolicsLuxCoreExt.jl

This file was deleted.

18 changes: 18 additions & 0 deletions ext/SymbolicsLuxExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
module SymbolicsLuxExt

using Lux
using Symbolics
using Lux.LuxCore
using Symbolics.SymbolicUtils

function Lux.NilSizePropagation.recursively_nillify(x::SymbolicUtils.BasicSymbolic{<:Vector{<:Real}})
Lux.NilSizePropagation.recursively_nillify(Symbolics.wrap(x))
end

@register_array_symbolic LuxCore.stateless_apply(
model::LuxCore.AbstractLuxLayer, x::AbstractArray, ps::Union{NamedTuple, <:AbstractVector}) begin
size = LuxCore.outputsize(model, x, LuxCore.Random.default_rng())
eltype = Real
end

end
19 changes: 0 additions & 19 deletions ext/SymbolicsNemoExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,32 +57,13 @@ function Symbolics.factor_use_nemo(poly::Num)
return sym_unit, sym_factors
end

# gcd(x^2 - y^2, x^3 - y^3) -> x - y
function Symbolics.gcd_use_nemo(poly1::Num, poly2::Num)
Symbolics.check_polynomial(poly1)
Symbolics.check_polynomial(poly2)
vars1 = Symbolics.get_variables(poly1)
vars2 = Symbolics.get_variables(poly2)
vars = vcat(vars1, vars2)
nemo_ring, nemo_vars = Nemo.polynomial_ring(Nemo.QQ, map(string, vars))
sym_to_nemo = Dict(vars .=> nemo_vars)
nemo_to_sym = Dict(v => k for (k, v) in sym_to_nemo)
nemo_poly1 = Symbolics.substitute(poly1, sym_to_nemo)
nemo_poly2 = Symbolics.substitute(poly2, sym_to_nemo)
nemo_gcd = Nemo.gcd(nemo_poly1, nemo_poly2)
sym_gcd = Symbolics.wrap(nemo_crude_evaluate(nemo_gcd, nemo_to_sym))
return sym_gcd
end


# Helps with precompilation time
PrecompileTools.@setup_workload begin
@variables a b c x y z
expr_with_params = expand((x + b)*(x^2 + 2x + 1)*(x^2 - a))
PrecompileTools.@compile_workload begin
symbolic_solve(expr_with_params, x, dropmultiplicity=false)
symbolic_solve(x^10 - a^10, x, dropmultiplicity=false)
symbolic_solve([x^2 - a^2, x + a], x)
end
end

Expand Down
10 changes: 10 additions & 0 deletions src/extra_functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,16 @@ end
@register_symbolic Base.rand(x)
@register_symbolic Base.randn(x)

@register_symbolic Base.clamp(x, y, z)

function derivative(::typeof(Base.clamp), args::NTuple{3, Any}, ::Val{1})
x, l, h = args
T = promote_type(symtype(x), symtype(l), symtype(h))
z = zero(T)
o = one(T)
ifelse(x<l, z, ifelse(x>h, z, o))
end

@register_symbolic Distributions.pdf(dist,x)
@register_symbolic Distributions.logpdf(dist,x)
@register_symbolic Distributions.cdf(dist,x)
Expand Down
71 changes: 15 additions & 56 deletions src/solver/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -165,31 +165,21 @@ function symbolic_solve(expr, x::T; dropmultiplicity = true, warns = true) where
expr = [expr]
expr_univar = false
end
if !expr_univar && x_univar
x = [x]
x_univar = false
end

if x_univar
sols = []
if expr_univar
sols = check_poly_inunivar(expr, x) ?
solve_univar(expr, x, dropmultiplicity=dropmultiplicity) :
ia_solve(expr, x, warns=warns)
isequal(sols, nothing) && return nothing
else
for i in eachindex(expr)
if !check_poly_inunivar(expr[i], x)
warns && @warn("Solve can not solve this input currently")
return nothing
end
end
sols = solve_multipoly(
expr, x, dropmultiplicity=dropmultiplicity, warns=warns)
isequal(sols, nothing) && return nothing
end

sols = check_poly_inunivar(expr, x) ?
solve_univar(expr, x, dropmultiplicity = dropmultiplicity) :
ia_solve(expr, x, warns = warns)
isequal(sols, nothing) && return nothing
sols = map(postprocess_root, sols)
return sols
end

if !expr_univar && !x_univar
if !x_univar
for e in expr
for var in x
if !check_poly_inunivar(e, var)
Expand All @@ -201,11 +191,13 @@ function symbolic_solve(expr, x::T; dropmultiplicity = true, warns = true) where

sols = solve_multivar(expr, x, dropmultiplicity=dropmultiplicity, warns=warns)
isequal(sols, nothing) && return nothing
for sol in sols
sols = convert(Vector{Any}, sols)
for i in eachindex(sols)
for var in x
if haskey(sol, var)
sol[var] = postprocess_root(sol[var])
end
sols[i][var] = postprocess_root(sols[i][var])
end
if length(collect(keys(sols[i]))) == 1
sols[i] = collect(values(sols[i]))[1]
end
end

Expand Down Expand Up @@ -310,39 +302,6 @@ function solve_univar(expression, x; dropmultiplicity=true)
return arr_roots
end

# You can compute the GCD between a system of polynomials by doing the following:
# Get the GCD between the first two polys,
# and get the GCD between this result and the following index,
# say: solve([x^2 - 1, x - 1, (x-1)^20], x)
# the GCD between the first two terms is obviously x-1,
# now we call gcd_use_nemo() on this term, and the following,
# gcd_use_nemo(x - 1, (x-1)^20), which is again x-1.
# now we just need to solve(x-1, x) to get the common root in this
# system of equations.
function solve_multipoly(polys::Vector, x::Num; dropmultiplicity = true, warns = true)
polys = unique(polys)

if length(polys) < 1
warns && @warn("No expressions entered")
return nothing
end
if length(polys) == 1
return solve_univar(polys[1], x, dropmultiplicity = dropmultiplicity)
end

gcd = gcd_use_nemo(polys[1], polys[2])

for i in eachindex(polys)[3:end]
gcd = gcd_use_nemo(gcd, polys[i])
end

if isequal(gcd, 1)
return []
end

return solve_univar(gcd, x, dropmultiplicity = dropmultiplicity)
end

function solve_multivar(eqs::Any, vars::Any; dropmultiplicity = true, warns = true)
throw("Groebner bases engine is required. Execute `using Groebner` to enable this functionality.")
end
4 changes: 0 additions & 4 deletions src/solver/nemo_stuff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,3 @@ function factor_use_nemo(poly::Any)
throw("Nemo is required. Execute `using Nemo` to enable this functionality.")
end

# gcd(x^2 - y^2, x^3 - y^3) -> x - y
function gcd_use_nemo(poly1::Any, poly2::Any)
throw("Nemo is required. Execute `using Nemo` to enable this functionality.")
end
3 changes: 3 additions & 0 deletions test/overloads.jl
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,9 @@ x = Num.(randn(10))
@test norm(x, 1) == norm(Symbolics.value.(x), 1)
@test norm(x, 1.2) == norm(Symbolics.value.(x), 1.2)

@test clamp.(x, 0, 1) == clamp.(Symbolics.value.(x), 0, 1)
@test isequal(Symbolics.derivative(clamp(a, 0, 1), a), ifelse(a < 0, 0, ifelse(a>1, 0, 1)))

@variables x[1:2]
@test isequal(scalarize(norm(x)), sqrt(abs2(x[1]) + abs2(x[2])))
@test isequal(scalarize(norm(x, Inf)), max(abs(x[1]), abs(x[2])))
Expand Down
11 changes: 3 additions & 8 deletions test/solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ end

@testset "Multivar parametric" begin
@variables x y a
@test isequal(symbolic_solve([x + a, a - 1], x), [])
@test isequal(symbolic_solve([x + a, a - 1], x), [-1])
@test isequal(symbolic_solve([x - a, y + a], [x, y]), [Dict(y => -a, x => a)])
@test isequal(symbolic_solve([x*y - a, x*y + x], [x, y]), [Dict(y => -1, x => -a)])
@test isequal(symbolic_solve([x*y - a, 1 ~ 3], [x, y]), [])
Expand All @@ -283,13 +283,13 @@ end
@test isequal(sol, [Dict(t => -5 / (-1 + u + v), w => 1 - u - v)])

sol = symbolic_solve([x-y, y-z], [x])
@test isequal(sol, [Dict(x=>z)])
@test isequal(sol, [z])

sol = symbolic_solve([x-y, y-z], [x, y])
@test isequal(sol, [Dict(x=>z, y=>z)])

sol = symbolic_solve([x + y - z, y - z], [x])
@test isequal(sol, [Dict(x=>0)])
@test isequal(sol, [0])
end

@testset "Factorisation" begin
Expand All @@ -310,11 +310,6 @@ end
@test isequal(expand(u*prod(factors) - f), 0)
end

@testset "GCD" begin
f1, f2 = x^2 - y^2, x^3 - y^3
@test isequal(x - y, Symbolics.gcd_use_nemo(f1, f2))
end


# Post Process roots #
@testset "Post Process roots" begin
Expand Down

0 comments on commit 8dd6f4b

Please sign in to comment.