Skip to content

Commit

Permalink
Merge branch 'master' into gd/adtypes
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas authored Jun 3, 2024
2 parents df51d34 + be17457 commit 07fdabd
Show file tree
Hide file tree
Showing 11 changed files with 78 additions and 92 deletions.
7 changes: 5 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Symbolics"
uuid = "0c5d862f-8b57-4792-8d23-62f2024744c7"
authors = ["Shashi Gowda <[email protected]>"]
version = "5.29.1"
version = "5.30.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down Expand Up @@ -39,6 +39,7 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
SymbolicLimits = "19f23fe9-fdab-4a78-91af-e7b7767979c3"
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c"

[weakdeps]
Groebner = "0b43b601-686d-58a3-8a1c-6623616c7cd4"
Expand Down Expand Up @@ -73,6 +74,7 @@ LogExpFunctions = "0.3"
LuxCore = "0.1.11"
MacroTools = "0.5"
NaNMath = "1"
PreallocationTools = "0.4"
PrecompileTools = "1"
RecipesBase = "1.1"
Reexport = "1"
Expand All @@ -83,9 +85,10 @@ SciMLBase = "2"
Setfield = "1"
SpecialFunctions = "2"
StaticArrays = "1.1"
SymPy = "2"
SymbolicIndexingInterface = "0.3.14"
SymbolicLimits = "0.2.0"
SymbolicUtils = "1.7"
SymbolicUtils = "2.0.2"
julia = "1.10"

[extras]
Expand Down
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@ Latexify = "0.15, 0.16"
OrdinaryDiffEq = "6.31"
Plots = "1.36"
StaticArrays = "1.5"
SymbolicUtils = "1"
SymbolicUtils = "2"
Symbolics = "5"
5 changes: 3 additions & 2 deletions src/Symbolics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ using PrecompileTools

import DomainSets: Domain

import SymbolicUtils: similarterm, iscall, operation, arguments, symtype, metadata
using TermInterface
import TermInterface: maketerm, iscall, operation, arguments, symtype, metadata

import SymbolicUtils: Term, Add, Mul, Pow, Sym, Div, BasicSymbolic,
FnType, @rule, Rewriters, substitute,
Expand All @@ -34,7 +35,7 @@ using PrecompileTools
import ArrayInterface
using RuntimeGeneratedFunctions
using SciMLBase, IfElse
using MacroTools
import MacroTools

using SymbolicIndexingInterface

Expand Down
43 changes: 11 additions & 32 deletions src/arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,32 +62,10 @@ end

ConstructionBase.constructorof(s::Type{<:ArrayOp{T}}) where {T} = ArrayOp{T}

function SymbolicUtils.similarterm(t::ArrayOp, f, args, _symtype = nothing; metadata = nothing)
oldargs = arguments(t)
if _symtype === nothing
_symtype = symtype(t)
end

if !all(isequal.(args, oldargs)) || !isequal(f, operation(t))
term = similarterm(t.term, f, args)
subs = Dict()
for (orig, new) in zip(oldargs, args)
isequal(orig, new) && continue
subs[orig] = new
end
if !isequal(f, operation(t))
subs[operation(t)] = f
end
expr = substitute(t.expr, subs)
expr = SymbolicUtils.term(operation(expr), arguments(expr)...)
else
term = t.term
expr = t.expr
end
if _symtype === nothing
_symtype = symtype(t)
end
return ArrayOp{_symtype}(t.output_idx, expr, t.reduce, term, t.shape, t.ranges, metadata)
function SymbolicUtils.maketerm(::Type{<:ArrayOp}, f, args, _symtype, m)
t = f(args...)
t isa Symbolic && !isnothing(metadata) ?
metadata(t, m) : t
end

shape(aop::ArrayOp) = aop.shape
Expand Down Expand Up @@ -147,6 +125,7 @@ macro arrayop(output_idx, expr, options...)
call = nothing

extra = []
isexpr = MacroTools.isexpr
for o in options
if isexpr(o, :call) && o.args[1] == :in
push!(rs, :($(o.args[2]) => $(o.args[3])))
Expand Down Expand Up @@ -611,7 +590,7 @@ function replace_by_scalarizing(ex, dict)

simterm = (x, f, args; kws...) -> begin
if metadata(x) !== nothing
similarterm(x, f, args; metadata=metadata(x))
maketerm(typeof(x), f, args, symtype(x), metadata(x))
else
f(args...)
end
Expand All @@ -622,7 +601,7 @@ function replace_by_scalarizing(ex, dict)
f = operation(x)
ff = replace_by_scalarizing(f, dict)
if metadata(x) !== nothing
similarterm(x, ff, arguments(x); metadata=metadata(x))
maketerm(typeof(x), ff, arguments(x), symtype(x), metadata(x))
else
ff(arguments(x)...)
end
Expand All @@ -636,11 +615,11 @@ function replace_by_scalarizing(ex, dict)
ex, simterm)
end

function prewalk_if(cond, f, t, similarterm)
function prewalk_if(cond, f, t, maketerm)
t′ = cond(t) ? f(t) : return t
if iscall(t′)
return similarterm(t′, operation(t′),
map(x->prewalk_if(cond, f, x, similarterm), arguments(t′)))
return maketerm(typeof(t′), TermInterface.head(t′),
map(x->prewalk_if(cond, f, x, maketerm), children(t′)))
else
return t′
end
Expand Down Expand Up @@ -773,7 +752,7 @@ function scalarize(arr)
elseif arr isa Num
wrap(scalarize(unwrap(arr)))
elseif iscall(arr) && symtype(arr) <: Number
t = similarterm(arr, operation(arr), map(scalarize, arguments(arr)), symtype(arr), metadata=metadata(arr))
t = maketerm(typeof(arr), operation(arr), map(scalarize, arguments(arr)), symtype(arr), metadata(arr))
iscall(t) ? scalarize_op(operation(t), t) : t
else
arr
Expand Down
4 changes: 2 additions & 2 deletions src/complex.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@ operation(a::ComplexTerm{T}) where T = Complex{T}
arguments(a::ComplexTerm) = [a.re, a.im]
metadata(a::ComplexTerm) = metadata(a.re)

function similarterm(t::ComplexTerm, f, args, symtype; metadata=nothing)
function maketerm(T::Type{<:ComplexTerm}, f, args, symtype, metadata)
if f <: Complex
ComplexTerm{real(f)}(args...)
else
similarterm(first(args), f, args, symtype; metadata=metadata)
maketerm(typeof(first(args)), f, args, symtype, metadata)
end
end

Expand Down
10 changes: 4 additions & 6 deletions src/diff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -541,8 +541,6 @@ function jacobian_sparsity(exprs::AbstractArray, vars::AbstractArray)
J = Int[]


simterm(x, f, args; kw...) = similarterm(x, f, args, symtype(x); kw...)

# This rewriter notes down which u's appear in a
# given du (whose index is stored in the `i` Ref)

Expand All @@ -552,7 +550,7 @@ function jacobian_sparsity(exprs::AbstractArray, vars::AbstractArray)
nothing
end

r = Rewriters.Postwalk(r, similarterm=simterm)
r = Rewriters.Postwalk(r)

for ii = 1:length(du)
i[] = ii
Expand Down Expand Up @@ -633,7 +631,7 @@ end

isidx(x) = x isa TermCombination

basic_simterm(t, g, args; kws...) = Term{Any}(g, args)
basic_mkterm(t, g, args, _, m) = metadata(Term{Any}(g, args), m)

let
# we do this in a let block so that Revise works on the list of rules
Expand Down Expand Up @@ -661,7 +659,7 @@ let
end
end
@rule ~x::issym => 0]
linearity_propagator = Fixpoint(Postwalk(Chain(linearity_rules); similarterm=basic_simterm))
linearity_propagator = Fixpoint(Postwalk(Chain(linearity_rules); maketerm=basic_mkterm))

global hessian_sparsity

Expand Down Expand Up @@ -695,7 +693,7 @@ let
u = map(value, vars)
idx(i) = TermCombination(Set([Dict(i=>1)]))
dict = Dict(u .=> idx.(1:length(u)))
f = Rewriters.Prewalk(x->haskey(dict, x) ? dict[x] : x; similarterm=basic_simterm)(expr)
f = Rewriters.Prewalk(x->haskey(dict, x) ? dict[x] : x; maketerm=basic_mkterm)(expr)
lp = linearity_propagator(f)
S = _sparse(lp, length(u))
S = full ? S : tril(S)
Expand Down
8 changes: 5 additions & 3 deletions src/semipoly.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ Base.:nameof(m::SemiMonomial) = Symbol(:SemiMonomial, m.p, m.coeff)
isop(x, op) = iscall(x) && operation(x) === op
isop(op) = Base.Fix2(isop, op)

bareterm(x, f, args; kw...) = Term{symtype(x)}(f, args)
simpleterm(T, f, args, sT, m) = Term{sT}(f, args)

function mark_and_exponentiate(expr, vars)
# Step 1
Expand All @@ -153,7 +153,7 @@ function mark_and_exponentiate(expr, vars)
@rule *(~~xs::(xs -> any(isop(+), xs))) => expand(Term(*, ~~xs))
@rule (~a::isop(+)) / (~b::issemimonomial) => +(map(x->x/~b, unsorted_arguments(~a))...)
@rule (~a::issemimonomial) / (~b::issemimonomial) => (~a) / (~b)]
expr′ = Postwalk(RestartedChain(rules), similarterm = bareterm)(expr′)
expr′ = Postwalk(RestartedChain(rules), maketerm = simpleterm)(expr′)
end

function semipolyform_terms(expr, vars)
Expand Down Expand Up @@ -424,7 +424,9 @@ function unwrap_sp(m::SemiMonomial)
end
function unwrap_sp(x)
x = unwrap(x)
iscall(x) ? similarterm(x, operation(x), map(unwrap_sp, unsorted_arguments(x))) : x
iscall(x) ? maketerm(typeof(x),
TermInterface.head(x), map(unwrap_sp,
TermInterface.children(x))) : x
end

function cautious_sum(nls)
Expand Down
5 changes: 3 additions & 2 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,8 @@ function diff2term(O, O_metadata::Union{Dict, Nothing, Base.ImmutableDict}=nothi
d_separator = 'ˍ'

if ds === nothing
return similarterm(O, operation(O), map(diff2term, arguments(O)), metadata = O_metadata isa Nothing ?
return maketerm(typeof(O), head(O), map(diff2term, children(O)),
symtype(O), O_metadata isa Nothing ?
metadata(O) : Base.ImmutableDict(metadata(O)..., O_metadata...))
else
oldop = operation(O)
Expand All @@ -128,7 +129,7 @@ function diff2term(O, O_metadata::Union{Dict, Nothing, Base.ImmutableDict}=nothi
return Sym{symtype(O)}(Symbol(opname, d_separator, ds))
end
newname = occursin(d_separator, opname) ? Symbol(opname, ds) : Symbol(opname, d_separator, ds)
return setname(similarterm(O, rename(oldop, newname), arguments(O), metadata = O_metadata isa Nothing ?
return setname(maketerm(typeof(O), rename(oldop, newname), children(O), symtype(O), O_metadata isa Nothing ?
metadata(O) : Base.ImmutableDict(metadata(O)..., O_metadata...)), newname)
end
end
Expand Down
12 changes: 6 additions & 6 deletions src/variable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -488,11 +488,11 @@ function fast_substitute(expr, subs; operator = Nothing)
end
canfold[] && return op(args...)
end
similarterm(expr,
maketerm(typeof(expr),
op,
args,
symtype(expr);
metadata = metadata(expr))
symtype(expr),
metadata(expr))
end
function fast_substitute(expr, pair::Pair; operator = Nothing)
a, b = pair
Expand All @@ -516,11 +516,11 @@ function fast_substitute(expr, pair::Pair; operator = Nothing)
end
canfold[] && return op(args...)
end
similarterm(expr,
maketerm(typeof(expr),
op,
args,
symtype(expr);
metadata = metadata(expr))
symtype(expr),
metadata(expr))
end

function getparent(x, val=_fail)
Expand Down
6 changes: 3 additions & 3 deletions test/arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,13 @@ end
@test getmetadata(unwrap(v[1]), TestMetaT) == 4
end

@testset "similarterm" begin
@testset "maketerm" begin
@variables A[1:5, 1:5] B[1:5, 1:5]

T = unwrap(3A)
@test isequal(T, similarterm(T, operation(T), arguments(T)))
@test isequal(T, Symbolics.maketerm(typeof(T), operation(T), arguments(T), symtype(T), nothing))
T2 = unwrap(3B)
@test isequal(T2, similarterm(T, operation(T), [*, 3, unwrap(B)]))
@test isequal(T2, Symbolics.maketerm(typeof(T), operation(T), [*, 3, unwrap(B)], symtype(T), nothing))
end

getdef(v) = getmetadata(v, Symbolics.VariableDefaultValue)
Expand Down
68 changes: 35 additions & 33 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,39 +18,41 @@ limit(a, N) = a == N + 1 ? 1 : a == 0 ? N : a
@register_symbolic limit(a, N)::Integer

if GROUP == "All" || GROUP == "Core"
@safetestset "Struct Test" begin include("struct.jl") end
@safetestset "Macro Test" begin include("macro.jl") end
@safetestset "Arrays" begin include("arrays.jl") end
@safetestset "View-setting" begin include("stencils.jl") end
@safetestset "Complex" begin include("complex.jl") end
@safetestset "Semi-polynomial" begin include("semipoly.jl") end
@safetestset "Fuzz Arrays" begin include("fuzz-arrays.jl") end
@safetestset "Differentiation Test" begin include("diff.jl") end
@safetestset "ADTypes Test" begin include("adtypes.jl") end
@safetestset "Difference Test" begin include("difference.jl") end
@safetestset "Degree Test" begin include("degree.jl") end
@safetestset "Coeff Test" begin include("coeff.jl") end
@safetestset "Parsing Test" begin include("parsing.jl") end
@safetestset "Is Linear or Affine Test" begin include("islinear_affine.jl") end
@safetestset "Linear Solver Test" begin include("linear_solver.jl") end
@safetestset "Algebraic Solver Test" begin include("solver.jl") end
@safetestset "Overloading Test" begin include("overloads.jl") end
@safetestset "ForwardDiff Extension Test" begin include("forwarddiff_symbolic_dual_ops.jl") end
@safetestset "Nested ForwardDiff Sparsity Test" begin include("nested_forwarddiff_sparsity.jl") end
@safetestset "Build Function Test" begin include("build_function.jl") end
@safetestset "Build Function Array Test" begin include("build_function_arrayofarray.jl") end
@safetestset "Build Function Array Test Named Tuples" begin include("build_function_arrayofarray_named_tuples.jl") end
@safetestset "Rewrite Helper Function Test" begin include("rewrite_helpers.jl") end
VERSION >= v"1.9" && @safetestset "Build Targets Test" begin include("build_targets.jl") end
@safetestset "Latexify Test" begin include("latexify.jl") end
@safetestset "Domain Test" begin include("domains.jl") end
@safetestset "SymPy Test" begin include("sympy.jl") end
@safetestset "Inequality Test" begin include("inequality.jl") end
@safetestset "Integral Test" begin include("integral.jl") end
@safetestset "CartesianIndex Test" begin include("cartesianindex.jl") end
@safetestset "LogExpFunctions Test" begin include("logexpfunctions.jl") end
@safetestset "LuxCore extensions Test" begin include("extensions/lux.jl") end
@safetestset "Registration without using Test" begin include("registration_without_using.jl") end
@testset begin
@safetestset "Struct Test" begin include("struct.jl") end
@safetestset "Macro Test" begin include("macro.jl") end
@safetestset "Arrays" begin include("arrays.jl") end
@safetestset "View-setting" begin include("stencils.jl") end
@safetestset "Complex" begin include("complex.jl") end
@safetestset "Semi-polynomial" begin include("semipoly.jl") end
@safetestset "Fuzz Arrays" begin include("fuzz-arrays.jl") end
@safetestset "Differentiation Test" begin include("diff.jl") end
@safetestset "ADTypes Test" begin include("adtypes.jl") end
@safetestset "Difference Test" begin include("difference.jl") end
@safetestset "Degree Test" begin include("degree.jl") end
@safetestset "Coeff Test" begin include("coeff.jl") end
@safetestset "Parsing Test" begin include("parsing.jl") end
@safetestset "Is Linear or Affine Test" begin include("islinear_affine.jl") end
@safetestset "Linear Solver Test" begin include("linear_solver.jl") end
@safetestset "Algebraic Solver Test" begin include("solver.jl") end
@safetestset "Overloading Test" begin include("overloads.jl") end
@safetestset "ForwardDiff Extension Test" begin include("forwarddiff_symbolic_dual_ops.jl") end
@safetestset "Nested ForwardDiff Sparsity Test" begin include("nested_forwarddiff_sparsity.jl") end
@safetestset "Build Function Test" begin include("build_function.jl") end
@safetestset "Build Function Array Test" begin include("build_function_arrayofarray.jl") end
@safetestset "Build Function Array Test Named Tuples" begin include("build_function_arrayofarray_named_tuples.jl") end
@safetestset "Rewrite Helper Function Test" begin include("rewrite_helpers.jl") end
VERSION >= v"1.9" && @safetestset "Build Targets Test" begin include("build_targets.jl") end
@safetestset "Latexify Test" begin include("latexify.jl") end
@safetestset "Domain Test" begin include("domains.jl") end
@safetestset "SymPy Test" begin include("sympy.jl") end
@safetestset "Inequality Test" begin include("inequality.jl") end
@safetestset "Integral Test" begin include("integral.jl") end
@safetestset "CartesianIndex Test" begin include("cartesianindex.jl") end
@safetestset "LogExpFunctions Test" begin include("logexpfunctions.jl") end
@safetestset "LuxCore extensions Test" begin include("extensions/lux.jl") end
@safetestset "Registration without using Test" begin include("registration_without_using.jl") end
end
end

if GROUP == "All" || GROUP == "GroebnerExt"
Expand Down

0 comments on commit 07fdabd

Please sign in to comment.