From 11838780448fd46ccb765240221e4f8a4c4a69a7 Mon Sep 17 00:00:00 2001 From: Shashi Gowda Date: Fri, 17 May 2024 10:01:53 -0400 Subject: [PATCH 01/18] revert SU lowerbound --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 5e95d407d..44c7d3774 100644 --- a/Project.toml +++ b/Project.toml @@ -83,7 +83,7 @@ SpecialFunctions = "2" StaticArrays = "1.1" SymbolicIndexingInterface = "0.3.14" SymbolicLimits = "0.2.0" -SymbolicUtils = "1.7" +SymbolicUtils = "1.6" julia = "1.10" [extras] From 7aadb1219317adae5700c624adeb27dee929aefc Mon Sep 17 00:00:00 2001 From: CompatHelper Julia Date: Sat, 25 May 2024 00:22:07 +0000 Subject: [PATCH 02/18] CompatHelper: add new compat entry for SymPy in [weakdeps] at version 2, (keep existing compat) --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index 44c7d3774..1f61facc4 100644 --- a/Project.toml +++ b/Project.toml @@ -81,6 +81,7 @@ SciMLBase = "2" Setfield = "1" SpecialFunctions = "2" StaticArrays = "1.1" +SymPy = "2" SymbolicIndexingInterface = "0.3.14" SymbolicLimits = "0.2.0" SymbolicUtils = "1.6" From c2ba2989c7b79e8dc4b65f5a626838304eab1a1b Mon Sep 17 00:00:00 2001 From: CompatHelper Julia Date: Sat, 25 May 2024 00:22:11 +0000 Subject: [PATCH 03/18] CompatHelper: add new compat entry for PreallocationTools in [weakdeps] at version 0.4, (keep existing compat) --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index 44c7d3774..99c503f6b 100644 --- a/Project.toml +++ b/Project.toml @@ -71,6 +71,7 @@ LogExpFunctions = "0.3" LuxCore = "0.1.11" MacroTools = "0.5" NaNMath = "1" +PreallocationTools = "0.4" PrecompileTools = "1" RecipesBase = "1.1" Reexport = "1" From e1a64224c800d0668e7aa27520b135d5d9e7f50a Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Thu, 30 May 2024 15:20:32 +0200 Subject: [PATCH 04/18] Bump SymbolicUtils to v2 --- Project.toml | 2 +- src/Symbolics.jl | 2 +- src/arrays.jl | 16 ++++++++-------- src/complex.jl | 4 ++-- src/diff.jl | 8 ++++---- src/semipoly.jl | 4 ++-- src/utils.jl | 4 ++-- src/variable.jl | 4 ++-- test/arrays.jl | 6 +++--- 9 files changed, 25 insertions(+), 25 deletions(-) diff --git a/Project.toml b/Project.toml index 049947cea..7a9eaae37 100644 --- a/Project.toml +++ b/Project.toml @@ -85,7 +85,7 @@ StaticArrays = "1.1" SymPy = "2" SymbolicIndexingInterface = "0.3.14" SymbolicLimits = "0.2.0" -SymbolicUtils = "1.6" +SymbolicUtils = "2" julia = "1.10" [extras] diff --git a/src/Symbolics.jl b/src/Symbolics.jl index 8cf10baa3..d2350bdfd 100644 --- a/src/Symbolics.jl +++ b/src/Symbolics.jl @@ -19,7 +19,7 @@ using PrecompileTools import DomainSets: Domain - import SymbolicUtils: similarterm, iscall, operation, arguments, symtype, metadata + import SymbolicUtils: maketerm, iscall, operation, arguments, symtype, metadata import SymbolicUtils: Term, Add, Mul, Pow, Sym, Div, BasicSymbolic, FnType, @rule, Rewriters, substitute, diff --git a/src/arrays.jl b/src/arrays.jl index 242cb05fe..cffa99aaf 100644 --- a/src/arrays.jl +++ b/src/arrays.jl @@ -62,14 +62,14 @@ end ConstructionBase.constructorof(s::Type{<:ArrayOp{T}}) where {T} = ArrayOp{T} -function SymbolicUtils.similarterm(t::ArrayOp, f, args, _symtype = nothing; metadata = nothing) +function SymbolicUtils.maketerm(t::ArrayOp, f, args, _symtype = nothing; metadata = nothing) oldargs = arguments(t) if _symtype === nothing _symtype = symtype(t) end if !all(isequal.(args, oldargs)) || !isequal(f, operation(t)) - term = similarterm(t.term, f, args) + term = maketerm(t.term, f, args) subs = Dict() for (orig, new) in zip(oldargs, args) isequal(orig, new) && continue @@ -611,7 +611,7 @@ function replace_by_scalarizing(ex, dict) simterm = (x, f, args; kws...) -> begin if metadata(x) !== nothing - similarterm(x, f, args; metadata=metadata(x)) + maketerm(x, f, args; metadata=metadata(x)) else f(args...) end @@ -622,7 +622,7 @@ function replace_by_scalarizing(ex, dict) f = operation(x) ff = replace_by_scalarizing(f, dict) if metadata(x) !== nothing - similarterm(x, ff, arguments(x); metadata=metadata(x)) + maketerm(x, ff, arguments(x); metadata=metadata(x)) else ff(arguments(x)...) end @@ -636,11 +636,11 @@ function replace_by_scalarizing(ex, dict) ex, simterm) end -function prewalk_if(cond, f, t, similarterm) +function prewalk_if(cond, f, t, maketerm) t′ = cond(t) ? f(t) : return t if iscall(t′) - return similarterm(t′, operation(t′), - map(x->prewalk_if(cond, f, x, similarterm), arguments(t′))) + return maketerm(t′, operation(t′), + map(x->prewalk_if(cond, f, x, maketerm), arguments(t′))) else return t′ end @@ -773,7 +773,7 @@ function scalarize(arr) elseif arr isa Num wrap(scalarize(unwrap(arr))) elseif iscall(arr) && symtype(arr) <: Number - t = similarterm(arr, operation(arr), map(scalarize, arguments(arr)), symtype(arr), metadata=metadata(arr)) + t = maketerm(arr, operation(arr), map(scalarize, arguments(arr)), symtype(arr), metadata=metadata(arr)) iscall(t) ? scalarize_op(operation(t), t) : t else arr diff --git a/src/complex.jl b/src/complex.jl index 27124d3b8..e94ff5e63 100644 --- a/src/complex.jl +++ b/src/complex.jl @@ -23,11 +23,11 @@ operation(a::ComplexTerm{T}) where T = Complex{T} arguments(a::ComplexTerm) = [a.re, a.im] metadata(a::ComplexTerm) = metadata(a.re) -function similarterm(t::ComplexTerm, f, args, symtype; metadata=nothing) +function maketerm(t::ComplexTerm, f, args, symtype; metadata=nothing) if f <: Complex ComplexTerm{real(f)}(args...) else - similarterm(first(args), f, args, symtype; metadata=metadata) + maketerm(first(args), f, args, symtype; metadata=metadata) end end diff --git a/src/diff.jl b/src/diff.jl index aeb2cc78e..ca8e960e5 100644 --- a/src/diff.jl +++ b/src/diff.jl @@ -541,7 +541,7 @@ function jacobian_sparsity(exprs::AbstractArray, vars::AbstractArray) J = Int[] - simterm(x, f, args; kw...) = similarterm(x, f, args, symtype(x); kw...) + simterm(x, f, args; kw...) = maketerm(x, f, args, symtype(x); kw...) # This rewriter notes down which u's appear in a # given du (whose index is stored in the `i` Ref) @@ -552,7 +552,7 @@ function jacobian_sparsity(exprs::AbstractArray, vars::AbstractArray) nothing end - r = Rewriters.Postwalk(r, similarterm=simterm) + r = Rewriters.Postwalk(r, maketerm=simterm) for ii = 1:length(du) i[] = ii @@ -661,7 +661,7 @@ let end end @rule ~x::issym => 0] - linearity_propagator = Fixpoint(Postwalk(Chain(linearity_rules); similarterm=basic_simterm)) + linearity_propagator = Fixpoint(Postwalk(Chain(linearity_rules); maketerm=basic_simterm)) global hessian_sparsity @@ -695,7 +695,7 @@ let u = map(value, vars) idx(i) = TermCombination(Set([Dict(i=>1)])) dict = Dict(u .=> idx.(1:length(u))) - f = Rewriters.Prewalk(x->haskey(dict, x) ? dict[x] : x; similarterm=basic_simterm)(expr) + f = Rewriters.Prewalk(x->haskey(dict, x) ? dict[x] : x; maketerm=basic_simterm)(expr) lp = linearity_propagator(f) S = _sparse(lp, length(u)) S = full ? S : tril(S) diff --git a/src/semipoly.jl b/src/semipoly.jl index 46b635ba3..e4f02298f 100644 --- a/src/semipoly.jl +++ b/src/semipoly.jl @@ -153,7 +153,7 @@ function mark_and_exponentiate(expr, vars) @rule *(~~xs::(xs -> any(isop(+), xs))) => expand(Term(*, ~~xs)) @rule (~a::isop(+)) / (~b::issemimonomial) => +(map(x->x/~b, unsorted_arguments(~a))...) @rule (~a::issemimonomial) / (~b::issemimonomial) => (~a) / (~b)] - expr′ = Postwalk(RestartedChain(rules), similarterm = bareterm)(expr′) + expr′ = Postwalk(RestartedChain(rules), maketerm = bareterm)(expr′) end function semipolyform_terms(expr, vars) @@ -424,7 +424,7 @@ function unwrap_sp(m::SemiMonomial) end function unwrap_sp(x) x = unwrap(x) - iscall(x) ? similarterm(x, operation(x), map(unwrap_sp, unsorted_arguments(x))) : x + iscall(x) ? maketerm(x, operation(x), map(unwrap_sp, unsorted_arguments(x))) : x end function cautious_sum(nls) diff --git a/src/utils.jl b/src/utils.jl index b6988f6ea..cd41ec377 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -113,7 +113,7 @@ function diff2term(O, O_metadata::Union{Dict, Nothing, Base.ImmutableDict}=nothi d_separator = 'ˍ' if ds === nothing - return similarterm(O, operation(O), map(diff2term, arguments(O)), metadata = O_metadata isa Nothing ? + return maketerm(O, operation(O), map(diff2term, arguments(O)), metadata = O_metadata isa Nothing ? metadata(O) : Base.ImmutableDict(metadata(O)..., O_metadata...)) else oldop = operation(O) @@ -128,7 +128,7 @@ function diff2term(O, O_metadata::Union{Dict, Nothing, Base.ImmutableDict}=nothi return Sym{symtype(O)}(Symbol(opname, d_separator, ds)) end newname = occursin(d_separator, opname) ? Symbol(opname, ds) : Symbol(opname, d_separator, ds) - return setname(similarterm(O, rename(oldop, newname), arguments(O), metadata = O_metadata isa Nothing ? + return setname(maketerm(O, rename(oldop, newname), arguments(O), metadata = O_metadata isa Nothing ? metadata(O) : Base.ImmutableDict(metadata(O)..., O_metadata...)), newname) end end diff --git a/src/variable.jl b/src/variable.jl index 8722fbb04..8b07d2ffc 100644 --- a/src/variable.jl +++ b/src/variable.jl @@ -488,7 +488,7 @@ function fast_substitute(expr, subs; operator = Nothing) end canfold[] && return op(args...) end - similarterm(expr, + maketerm(expr, op, args, symtype(expr); @@ -516,7 +516,7 @@ function fast_substitute(expr, pair::Pair; operator = Nothing) end canfold[] && return op(args...) end - similarterm(expr, + maketerm(expr, op, args, symtype(expr); diff --git a/test/arrays.jl b/test/arrays.jl index 50c71ff5e..de16924a3 100644 --- a/test/arrays.jl +++ b/test/arrays.jl @@ -73,13 +73,13 @@ end @test getmetadata(unwrap(v[1]), TestMetaT) == 4 end -@testset "similarterm" begin +@testset "maketerm" begin @variables A[1:5, 1:5] B[1:5, 1:5] T = unwrap(3A) - @test isequal(T, similarterm(T, operation(T), arguments(T))) + @test isequal(T, maketerm(T, operation(T), arguments(T))) T2 = unwrap(3B) - @test isequal(T2, similarterm(T, operation(T), [*, 3, unwrap(B)])) + @test isequal(T2, maketerm(T, operation(T), [*, 3, unwrap(B)])) end getdef(v) = getmetadata(v, Symbolics.VariableDefaultValue) From 059d56ce67d22b8c05c558fd8e5a8d4b2f99dbbc Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Thu, 30 May 2024 16:54:39 +0200 Subject: [PATCH 05/18] Fix more tests --- src/arrays.jl | 4 ++-- src/diff.jl | 6 +++--- test/arrays.jl | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/arrays.jl b/src/arrays.jl index cffa99aaf..f72b3cac1 100644 --- a/src/arrays.jl +++ b/src/arrays.jl @@ -611,7 +611,7 @@ function replace_by_scalarizing(ex, dict) simterm = (x, f, args; kws...) -> begin if metadata(x) !== nothing - maketerm(x, f, args; metadata=metadata(x)) + maketerm(typeof(x), f, args, symtype(x), metadata(x)) else f(args...) end @@ -622,7 +622,7 @@ function replace_by_scalarizing(ex, dict) f = operation(x) ff = replace_by_scalarizing(f, dict) if metadata(x) !== nothing - maketerm(x, ff, arguments(x); metadata=metadata(x)) + maketerm(typeof(x), ff, arguments(x), symtype(x), metadata(x)) else ff(arguments(x)...) end diff --git a/src/diff.jl b/src/diff.jl index ca8e960e5..e129a4060 100644 --- a/src/diff.jl +++ b/src/diff.jl @@ -552,7 +552,7 @@ function jacobian_sparsity(exprs::AbstractArray, vars::AbstractArray) nothing end - r = Rewriters.Postwalk(r, maketerm=simterm) + r = Rewriters.Postwalk(r, similarterm=simterm) for ii = 1:length(du) i[] = ii @@ -661,7 +661,7 @@ let end end @rule ~x::issym => 0] - linearity_propagator = Fixpoint(Postwalk(Chain(linearity_rules); maketerm=basic_simterm)) + linearity_propagator = Fixpoint(Postwalk(Chain(linearity_rules); similarterm=basic_simterm)) global hessian_sparsity @@ -695,7 +695,7 @@ let u = map(value, vars) idx(i) = TermCombination(Set([Dict(i=>1)])) dict = Dict(u .=> idx.(1:length(u))) - f = Rewriters.Prewalk(x->haskey(dict, x) ? dict[x] : x; maketerm=basic_simterm)(expr) + f = Rewriters.Prewalk(x->haskey(dict, x) ? dict[x] : x; similarterm=basic_simterm)(expr) lp = linearity_propagator(f) S = _sparse(lp, length(u)) S = full ? S : tril(S) diff --git a/test/arrays.jl b/test/arrays.jl index de16924a3..6faf686f3 100644 --- a/test/arrays.jl +++ b/test/arrays.jl @@ -77,9 +77,9 @@ end @variables A[1:5, 1:5] B[1:5, 1:5] T = unwrap(3A) - @test isequal(T, maketerm(T, operation(T), arguments(T))) + @test isequal(T, Symbolics.maketerm(T, operation(T), arguments(T))) T2 = unwrap(3B) - @test isequal(T2, maketerm(T, operation(T), [*, 3, unwrap(B)])) + @test isequal(T2, Symbolics.maketerm(T, operation(T), [*, 3, unwrap(B)])) end getdef(v) = getmetadata(v, Symbolics.VariableDefaultValue) From b1f947a158e881f1a55c6138ac7b3e1e6b13631f Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Fri, 31 May 2024 06:41:00 +0200 Subject: [PATCH 06/18] update one more term defintion --- src/arrays.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/arrays.jl b/src/arrays.jl index f72b3cac1..44725902e 100644 --- a/src/arrays.jl +++ b/src/arrays.jl @@ -69,7 +69,7 @@ function SymbolicUtils.maketerm(t::ArrayOp, f, args, _symtype = nothing; metadat end if !all(isequal.(args, oldargs)) || !isequal(f, operation(t)) - term = maketerm(t.term, f, args) + term = Symbolics.maketerm(typeof(t.term), f, args, symtype(t.term), nothing) subs = Dict() for (orig, new) in zip(oldargs, args) isequal(orig, new) && continue From 62da14429b77d2383c56b80d7e064839a174604b Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Fri, 31 May 2024 07:18:47 +0200 Subject: [PATCH 07/18] run all tests --- src/arrays.jl | 2 +- test/runtests.jl | 66 +++++++++++++++++++++++++----------------------- 2 files changed, 35 insertions(+), 33 deletions(-) diff --git a/src/arrays.jl b/src/arrays.jl index 44725902e..a14738e51 100644 --- a/src/arrays.jl +++ b/src/arrays.jl @@ -773,7 +773,7 @@ function scalarize(arr) elseif arr isa Num wrap(scalarize(unwrap(arr))) elseif iscall(arr) && symtype(arr) <: Number - t = maketerm(arr, operation(arr), map(scalarize, arguments(arr)), symtype(arr), metadata=metadata(arr)) + t = maketerm(typeof(arr), operation(arr), map(scalarize, arguments(arr)), symtype(arr), metadata(arr)) iscall(t) ? scalarize_op(operation(t), t) : t else arr diff --git a/test/runtests.jl b/test/runtests.jl index fbd79cab7..4d848e695 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -18,38 +18,40 @@ limit(a, N) = a == N + 1 ? 1 : a == 0 ? N : a @register_symbolic limit(a, N)::Integer if GROUP == "All" || GROUP == "Core" - @safetestset "Struct Test" begin include("struct.jl") end - @safetestset "Macro Test" begin include("macro.jl") end - @safetestset "Arrays" begin include("arrays.jl") end - @safetestset "View-setting" begin include("stencils.jl") end - @safetestset "Complex" begin include("complex.jl") end - @safetestset "Semi-polynomial" begin include("semipoly.jl") end - @safetestset "Fuzz Arrays" begin include("fuzz-arrays.jl") end - @safetestset "Differentiation Test" begin include("diff.jl") end - @safetestset "Difference Test" begin include("difference.jl") end - @safetestset "Degree Test" begin include("degree.jl") end - @safetestset "Coeff Test" begin include("coeff.jl") end - @safetestset "Parsing Test" begin include("parsing.jl") end - @safetestset "Is Linear or Affine Test" begin include("islinear_affine.jl") end - @safetestset "Linear Solver Test" begin include("linear_solver.jl") end - @safetestset "Algebraic Solver Test" begin include("solver.jl") end - @safetestset "Overloading Test" begin include("overloads.jl") end - @safetestset "ForwardDiff Extension Test" begin include("forwarddiff_symbolic_dual_ops.jl") end - @safetestset "Nested ForwardDiff Sparsity Test" begin include("nested_forwarddiff_sparsity.jl") end - @safetestset "Build Function Test" begin include("build_function.jl") end - @safetestset "Build Function Array Test" begin include("build_function_arrayofarray.jl") end - @safetestset "Build Function Array Test Named Tuples" begin include("build_function_arrayofarray_named_tuples.jl") end - @safetestset "Rewrite Helper Function Test" begin include("rewrite_helpers.jl") end - VERSION >= v"1.9" && @safetestset "Build Targets Test" begin include("build_targets.jl") end - @safetestset "Latexify Test" begin include("latexify.jl") end - @safetestset "Domain Test" begin include("domains.jl") end - @safetestset "SymPy Test" begin include("sympy.jl") end - @safetestset "Inequality Test" begin include("inequality.jl") end - @safetestset "Integral Test" begin include("integral.jl") end - @safetestset "CartesianIndex Test" begin include("cartesianindex.jl") end - @safetestset "LogExpFunctions Test" begin include("logexpfunctions.jl") end - @safetestset "LuxCore extensions Test" begin include("extensions/lux.jl") end - @safetestset "Registration without using Test" begin include("registration_without_using.jl") end + @testset begin + @safetestset "Struct Test" begin include("struct.jl") end + @safetestset "Macro Test" begin include("macro.jl") end + @safetestset "Arrays" begin include("arrays.jl") end + @safetestset "View-setting" begin include("stencils.jl") end + @safetestset "Complex" begin include("complex.jl") end + @safetestset "Semi-polynomial" begin include("semipoly.jl") end + @safetestset "Fuzz Arrays" begin include("fuzz-arrays.jl") end + @safetestset "Differentiation Test" begin include("diff.jl") end + @safetestset "Difference Test" begin include("difference.jl") end + @safetestset "Degree Test" begin include("degree.jl") end + @safetestset "Coeff Test" begin include("coeff.jl") end + @safetestset "Parsing Test" begin include("parsing.jl") end + @safetestset "Is Linear or Affine Test" begin include("islinear_affine.jl") end + @safetestset "Linear Solver Test" begin include("linear_solver.jl") end + @safetestset "Algebraic Solver Test" begin include("solver.jl") end + @safetestset "Overloading Test" begin include("overloads.jl") end + @safetestset "ForwardDiff Extension Test" begin include("forwarddiff_symbolic_dual_ops.jl") end + @safetestset "Nested ForwardDiff Sparsity Test" begin include("nested_forwarddiff_sparsity.jl") end + @safetestset "Build Function Test" begin include("build_function.jl") end + @safetestset "Build Function Array Test" begin include("build_function_arrayofarray.jl") end + @safetestset "Build Function Array Test Named Tuples" begin include("build_function_arrayofarray_named_tuples.jl") end + @safetestset "Rewrite Helper Function Test" begin include("rewrite_helpers.jl") end + VERSION >= v"1.9" && @safetestset "Build Targets Test" begin include("build_targets.jl") end + @safetestset "Latexify Test" begin include("latexify.jl") end + @safetestset "Domain Test" begin include("domains.jl") end + @safetestset "SymPy Test" begin include("sympy.jl") end + @safetestset "Inequality Test" begin include("inequality.jl") end + @safetestset "Integral Test" begin include("integral.jl") end + @safetestset "CartesianIndex Test" begin include("cartesianindex.jl") end + @safetestset "LogExpFunctions Test" begin include("logexpfunctions.jl") end + @safetestset "LuxCore extensions Test" begin include("extensions/lux.jl") end + @safetestset "Registration without using Test" begin include("registration_without_using.jl") end + end end if GROUP == "All" || GROUP == "GroebnerExt" From 334663c53bb253d90e7c03fb6f976daa831ec9af Mon Sep 17 00:00:00 2001 From: Vaibhav Kumar Dixit Date: Sat, 1 Jun 2024 23:20:47 -0400 Subject: [PATCH 08/18] maketerm metadata to arg not kwarg --- src/variable.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/variable.jl b/src/variable.jl index 8b07d2ffc..10e412d5b 100644 --- a/src/variable.jl +++ b/src/variable.jl @@ -491,8 +491,8 @@ function fast_substitute(expr, subs; operator = Nothing) maketerm(expr, op, args, - symtype(expr); - metadata = metadata(expr)) + symtype(expr), + metadata(expr)) end function fast_substitute(expr, pair::Pair; operator = Nothing) a, b = pair @@ -519,8 +519,8 @@ function fast_substitute(expr, pair::Pair; operator = Nothing) maketerm(expr, op, args, - symtype(expr); - metadata = metadata(expr)) + symtype(expr), + metadata(expr)) end function getparent(x, val=_fail) From b2a7b575c5474ed4097497d1ef5fd054847ef1a7 Mon Sep 17 00:00:00 2001 From: Vaibhav Kumar Dixit Date: Sat, 1 Jun 2024 23:21:00 -0400 Subject: [PATCH 09/18] Update src/complex.jl --- src/complex.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/complex.jl b/src/complex.jl index e94ff5e63..7037051e5 100644 --- a/src/complex.jl +++ b/src/complex.jl @@ -27,7 +27,7 @@ function maketerm(t::ComplexTerm, f, args, symtype; metadata=nothing) if f <: Complex ComplexTerm{real(f)}(args...) else - maketerm(first(args), f, args, symtype; metadata=metadata) + maketerm(first(args), f, args, symtype, metadata) end end From 7676c08d36b9cc94ff9b25d49aaa56d7c8da6977 Mon Sep 17 00:00:00 2001 From: Vaibhav Kumar Dixit Date: Sat, 1 Jun 2024 23:21:04 -0400 Subject: [PATCH 10/18] Update src/diff.jl --- src/diff.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diff.jl b/src/diff.jl index e129a4060..616f73ed9 100644 --- a/src/diff.jl +++ b/src/diff.jl @@ -541,7 +541,7 @@ function jacobian_sparsity(exprs::AbstractArray, vars::AbstractArray) J = Int[] - simterm(x, f, args; kw...) = maketerm(x, f, args, symtype(x); kw...) + simterm(x, f, args; metadata = nothing, kw...) = maketerm(x, f, args, symtype(x), metadata; kw...) # This rewriter notes down which u's appear in a # given du (whose index is stored in the `i` Ref) From 2677315f335018f6ae3678bbeb173def164c6f8d Mon Sep 17 00:00:00 2001 From: Shashi Gowda Date: Sat, 1 Jun 2024 23:46:52 -0400 Subject: [PATCH 11/18] some maketerm updates --- src/complex.jl | 4 ++-- src/diff.jl | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/complex.jl b/src/complex.jl index 7037051e5..849cbc94b 100644 --- a/src/complex.jl +++ b/src/complex.jl @@ -23,11 +23,11 @@ operation(a::ComplexTerm{T}) where T = Complex{T} arguments(a::ComplexTerm) = [a.re, a.im] metadata(a::ComplexTerm) = metadata(a.re) -function maketerm(t::ComplexTerm, f, args, symtype; metadata=nothing) +function maketerm(T::Type{<:ComplexTerm}, f, args, symtype, metadata) if f <: Complex ComplexTerm{real(f)}(args...) else - maketerm(first(args), f, args, symtype, metadata) + maketerm(typeof(first(args)), f, args, symtype, metadata) end end diff --git a/src/diff.jl b/src/diff.jl index 616f73ed9..016fe8c3a 100644 --- a/src/diff.jl +++ b/src/diff.jl @@ -541,7 +541,7 @@ function jacobian_sparsity(exprs::AbstractArray, vars::AbstractArray) J = Int[] - simterm(x, f, args; metadata = nothing, kw...) = maketerm(x, f, args, symtype(x), metadata; kw...) + mkterm(x, f, args, _, m) = maketerm(x, f, args, symtype(x), m) # This rewriter notes down which u's appear in a # given du (whose index is stored in the `i` Ref) @@ -552,7 +552,7 @@ function jacobian_sparsity(exprs::AbstractArray, vars::AbstractArray) nothing end - r = Rewriters.Postwalk(r, similarterm=simterm) + r = Rewriters.Postwalk(r, maketerm=mkterm) for ii = 1:length(du) i[] = ii @@ -633,7 +633,7 @@ end isidx(x) = x isa TermCombination -basic_simterm(t, g, args; kws...) = Term{Any}(g, args) +basic_mkterm(t, g, args, _, m) = metadata(Term{Any}(g, args), m) let # we do this in a let block so that Revise works on the list of rules @@ -661,7 +661,7 @@ let end end @rule ~x::issym => 0] - linearity_propagator = Fixpoint(Postwalk(Chain(linearity_rules); similarterm=basic_simterm)) + linearity_propagator = Fixpoint(Postwalk(Chain(linearity_rules); maketerm=basic_mkterm)) global hessian_sparsity From 5cff3ad9bfdd43dd42ba3bea10d9678258b5c5bb Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Sun, 2 Jun 2024 12:05:33 -0400 Subject: [PATCH 12/18] Allow variable number of args in bareterm --- src/semipoly.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/semipoly.jl b/src/semipoly.jl index e4f02298f..6b976a0f2 100644 --- a/src/semipoly.jl +++ b/src/semipoly.jl @@ -136,7 +136,7 @@ Base.:nameof(m::SemiMonomial) = Symbol(:SemiMonomial, m.p, m.coeff) isop(x, op) = iscall(x) && operation(x) === op isop(op) = Base.Fix2(isop, op) -bareterm(x, f, args; kw...) = Term{symtype(x)}(f, args) +bareterm(x, f, args, as...; kw...) = Term{symtype(x)}(f, args) function mark_and_exponentiate(expr, vars) # Step 1 From 24c5e3e6200158a2fbca69cc7124bbda22dfab8b Mon Sep 17 00:00:00 2001 From: Shashi Gowda Date: Sun, 2 Jun 2024 22:04:31 -0400 Subject: [PATCH 13/18] fix some imports --- Project.toml | 1 + src/Symbolics.jl | 5 +++-- src/arrays.jl | 1 + src/diff.jl | 2 +- src/semipoly.jl | 6 ++++-- 5 files changed, 10 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index 7a9eaae37..7a90761df 100644 --- a/Project.toml +++ b/Project.toml @@ -38,6 +38,7 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" SymbolicLimits = "19f23fe9-fdab-4a78-91af-e7b7767979c3" SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" +TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c" [weakdeps] Groebner = "0b43b601-686d-58a3-8a1c-6623616c7cd4" diff --git a/src/Symbolics.jl b/src/Symbolics.jl index d2350bdfd..a20ab7f01 100644 --- a/src/Symbolics.jl +++ b/src/Symbolics.jl @@ -19,7 +19,8 @@ using PrecompileTools import DomainSets: Domain - import SymbolicUtils: maketerm, iscall, operation, arguments, symtype, metadata + using TermInterface + import TermInterface: maketerm, iscall, operation, arguments, symtype, metadata import SymbolicUtils: Term, Add, Mul, Pow, Sym, Div, BasicSymbolic, FnType, @rule, Rewriters, substitute, @@ -34,7 +35,7 @@ using PrecompileTools import ArrayInterface using RuntimeGeneratedFunctions using SciMLBase, IfElse - using MacroTools + import MacroTools using SymbolicIndexingInterface diff --git a/src/arrays.jl b/src/arrays.jl index a14738e51..88471fe8b 100644 --- a/src/arrays.jl +++ b/src/arrays.jl @@ -147,6 +147,7 @@ macro arrayop(output_idx, expr, options...) call = nothing extra = [] + isexpr = MacroTools.isexpr for o in options if isexpr(o, :call) && o.args[1] == :in push!(rs, :($(o.args[2]) => $(o.args[3]))) diff --git a/src/diff.jl b/src/diff.jl index 016fe8c3a..030b68212 100644 --- a/src/diff.jl +++ b/src/diff.jl @@ -695,7 +695,7 @@ let u = map(value, vars) idx(i) = TermCombination(Set([Dict(i=>1)])) dict = Dict(u .=> idx.(1:length(u))) - f = Rewriters.Prewalk(x->haskey(dict, x) ? dict[x] : x; similarterm=basic_simterm)(expr) + f = Rewriters.Prewalk(x->haskey(dict, x) ? dict[x] : x; maketerm=basic_maketerm)(expr) lp = linearity_propagator(f) S = _sparse(lp, length(u)) S = full ? S : tril(S) diff --git a/src/semipoly.jl b/src/semipoly.jl index e4f02298f..edcd5e267 100644 --- a/src/semipoly.jl +++ b/src/semipoly.jl @@ -136,7 +136,7 @@ Base.:nameof(m::SemiMonomial) = Symbol(:SemiMonomial, m.p, m.coeff) isop(x, op) = iscall(x) && operation(x) === op isop(op) = Base.Fix2(isop, op) -bareterm(x, f, args; kw...) = Term{symtype(x)}(f, args) +bareterm(T, f, args, _, m) = Term{symtype(T)}(f, args) function mark_and_exponentiate(expr, vars) # Step 1 @@ -424,7 +424,9 @@ function unwrap_sp(m::SemiMonomial) end function unwrap_sp(x) x = unwrap(x) - iscall(x) ? maketerm(x, operation(x), map(unwrap_sp, unsorted_arguments(x))) : x + iscall(x) ? maketerm(typeof(x), + TermInterface.head(x), map(unwrap_sp, + TermInterface.children(x))) : x end function cautious_sum(nls) From 32087bd9f6473bd2ade0baf06979fcac05ad24f2 Mon Sep 17 00:00:00 2001 From: Shashi Gowda Date: Sun, 2 Jun 2024 22:28:18 -0400 Subject: [PATCH 14/18] w --- src/diff.jl | 6 ++---- src/semipoly.jl | 4 ++-- src/utils.jl | 5 +++-- 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/src/diff.jl b/src/diff.jl index 030b68212..b4e7a4c6e 100644 --- a/src/diff.jl +++ b/src/diff.jl @@ -541,8 +541,6 @@ function jacobian_sparsity(exprs::AbstractArray, vars::AbstractArray) J = Int[] - mkterm(x, f, args, _, m) = maketerm(x, f, args, symtype(x), m) - # This rewriter notes down which u's appear in a # given du (whose index is stored in the `i` Ref) @@ -552,7 +550,7 @@ function jacobian_sparsity(exprs::AbstractArray, vars::AbstractArray) nothing end - r = Rewriters.Postwalk(r, maketerm=mkterm) + r = Rewriters.Postwalk(r) for ii = 1:length(du) i[] = ii @@ -695,7 +693,7 @@ let u = map(value, vars) idx(i) = TermCombination(Set([Dict(i=>1)])) dict = Dict(u .=> idx.(1:length(u))) - f = Rewriters.Prewalk(x->haskey(dict, x) ? dict[x] : x; maketerm=basic_maketerm)(expr) + f = Rewriters.Prewalk(x->haskey(dict, x) ? dict[x] : x; maketerm=basic_mkterm)(expr) lp = linearity_propagator(f) S = _sparse(lp, length(u)) S = full ? S : tril(S) diff --git a/src/semipoly.jl b/src/semipoly.jl index edcd5e267..782cd0bbe 100644 --- a/src/semipoly.jl +++ b/src/semipoly.jl @@ -136,7 +136,7 @@ Base.:nameof(m::SemiMonomial) = Symbol(:SemiMonomial, m.p, m.coeff) isop(x, op) = iscall(x) && operation(x) === op isop(op) = Base.Fix2(isop, op) -bareterm(T, f, args, _, m) = Term{symtype(T)}(f, args) +simpleterm(T, f, args, sT, m) = Term{sT}(f, args) function mark_and_exponentiate(expr, vars) # Step 1 @@ -153,7 +153,7 @@ function mark_and_exponentiate(expr, vars) @rule *(~~xs::(xs -> any(isop(+), xs))) => expand(Term(*, ~~xs)) @rule (~a::isop(+)) / (~b::issemimonomial) => +(map(x->x/~b, unsorted_arguments(~a))...) @rule (~a::issemimonomial) / (~b::issemimonomial) => (~a) / (~b)] - expr′ = Postwalk(RestartedChain(rules), maketerm = bareterm)(expr′) + expr′ = Postwalk(RestartedChain(rules), maketerm = simpleterm)(expr′) end function semipolyform_terms(expr, vars) diff --git a/src/utils.jl b/src/utils.jl index cd41ec377..f067a14c6 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -113,7 +113,8 @@ function diff2term(O, O_metadata::Union{Dict, Nothing, Base.ImmutableDict}=nothi d_separator = 'ˍ' if ds === nothing - return maketerm(O, operation(O), map(diff2term, arguments(O)), metadata = O_metadata isa Nothing ? + return maketerm(typeof(O), head(O), map(diff2term, children(O)), + symtype(O), O_metadata isa Nothing ? metadata(O) : Base.ImmutableDict(metadata(O)..., O_metadata...)) else oldop = operation(O) @@ -128,7 +129,7 @@ function diff2term(O, O_metadata::Union{Dict, Nothing, Base.ImmutableDict}=nothi return Sym{symtype(O)}(Symbol(opname, d_separator, ds)) end newname = occursin(d_separator, opname) ? Symbol(opname, ds) : Symbol(opname, d_separator, ds) - return setname(maketerm(O, rename(oldop, newname), arguments(O), metadata = O_metadata isa Nothing ? + return setname(maketerm(typeof(O), rename(oldop, newname), children(O), symtype(O), O_metadata isa Nothing ? metadata(O) : Base.ImmutableDict(metadata(O)..., O_metadata...)), newname) end end From d9b74bbedd90869ab0d5a08b156a29554861fb1c Mon Sep 17 00:00:00 2001 From: Shashi Gowda Date: Sun, 2 Jun 2024 22:57:02 -0400 Subject: [PATCH 15/18] fix up maketerm calls --- src/arrays.jl | 34 ++++++---------------------------- src/variable.jl | 4 ++-- test/arrays.jl | 4 ++-- 3 files changed, 10 insertions(+), 32 deletions(-) diff --git a/src/arrays.jl b/src/arrays.jl index 88471fe8b..06b37a493 100644 --- a/src/arrays.jl +++ b/src/arrays.jl @@ -62,32 +62,10 @@ end ConstructionBase.constructorof(s::Type{<:ArrayOp{T}}) where {T} = ArrayOp{T} -function SymbolicUtils.maketerm(t::ArrayOp, f, args, _symtype = nothing; metadata = nothing) - oldargs = arguments(t) - if _symtype === nothing - _symtype = symtype(t) - end - - if !all(isequal.(args, oldargs)) || !isequal(f, operation(t)) - term = Symbolics.maketerm(typeof(t.term), f, args, symtype(t.term), nothing) - subs = Dict() - for (orig, new) in zip(oldargs, args) - isequal(orig, new) && continue - subs[orig] = new - end - if !isequal(f, operation(t)) - subs[operation(t)] = f - end - expr = substitute(t.expr, subs) - expr = SymbolicUtils.term(operation(expr), arguments(expr)...) - else - term = t.term - expr = t.expr - end - if _symtype === nothing - _symtype = symtype(t) - end - return ArrayOp{_symtype}(t.output_idx, expr, t.reduce, term, t.shape, t.ranges, metadata) +function SymbolicUtils.maketerm(::Type{<:ArrayOp}, f, args, _symtype, m) + t = f(args...) + t isa Symbolic && !isnothing(metadata) ? + metadata(t, m) : t end shape(aop::ArrayOp) = aop.shape @@ -640,8 +618,8 @@ end function prewalk_if(cond, f, t, maketerm) t′ = cond(t) ? f(t) : return t if iscall(t′) - return maketerm(t′, operation(t′), - map(x->prewalk_if(cond, f, x, maketerm), arguments(t′))) + return maketerm(typeof(t′), TermInterface.head(t′), + map(x->prewalk_if(cond, f, x, maketerm), children(t′))) else return t′ end diff --git a/src/variable.jl b/src/variable.jl index 10e412d5b..7e2a18572 100644 --- a/src/variable.jl +++ b/src/variable.jl @@ -488,7 +488,7 @@ function fast_substitute(expr, subs; operator = Nothing) end canfold[] && return op(args...) end - maketerm(expr, + maketerm(typeof(expr), op, args, symtype(expr), @@ -516,7 +516,7 @@ function fast_substitute(expr, pair::Pair; operator = Nothing) end canfold[] && return op(args...) end - maketerm(expr, + maketerm(typeof(expr), op, args, symtype(expr), diff --git a/test/arrays.jl b/test/arrays.jl index 6faf686f3..563e14e85 100644 --- a/test/arrays.jl +++ b/test/arrays.jl @@ -77,9 +77,9 @@ end @variables A[1:5, 1:5] B[1:5, 1:5] T = unwrap(3A) - @test isequal(T, Symbolics.maketerm(T, operation(T), arguments(T))) + @test isequal(T, Symbolics.maketerm(typeof(T), operation(T), arguments(T), symtype(T), nothing)) T2 = unwrap(3B) - @test isequal(T2, Symbolics.maketerm(T, operation(T), [*, 3, unwrap(B)])) + @test isequal(T2, Symbolics.maketerm(typeof(T), operation(T), [*, 3, unwrap(B)], symtype(T), nothing)) end getdef(v) = getmetadata(v, Symbolics.VariableDefaultValue) From 990f741832e4a9832370b66c3eb8b957ec47fd70 Mon Sep 17 00:00:00 2001 From: Shashi Gowda Date: Sun, 2 Jun 2024 23:36:03 -0400 Subject: [PATCH 16/18] lower bound SU --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 7a90761df..2f440ac2f 100644 --- a/Project.toml +++ b/Project.toml @@ -86,7 +86,7 @@ StaticArrays = "1.1" SymPy = "2" SymbolicIndexingInterface = "0.3.14" SymbolicLimits = "0.2.0" -SymbolicUtils = "2" +SymbolicUtils = "2.0.2" julia = "1.10" [extras] From 777b480c1a815c338b11e1b23687747645baf758 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Mon, 3 Jun 2024 07:13:57 -0400 Subject: [PATCH 17/18] Update Project.toml --- docs/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/Project.toml b/docs/Project.toml index 9a18ffb63..b5f1b93aa 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -15,5 +15,5 @@ Latexify = "0.15, 0.16" OrdinaryDiffEq = "6.31" Plots = "1.36" StaticArrays = "1.5" -SymbolicUtils = "1" +SymbolicUtils = "2" Symbolics = "5" From be17457eec7259e8f28c1b6b4aeea2d837f78aee Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Mon, 3 Jun 2024 07:46:53 -0400 Subject: [PATCH 18/18] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 2f440ac2f..6c6f710fb 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Symbolics" uuid = "0c5d862f-8b57-4792-8d23-62f2024744c7" authors = ["Shashi Gowda "] -version = "5.28.1" +version = "5.29.0" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"