From 2677315f335018f6ae3678bbeb173def164c6f8d Mon Sep 17 00:00:00 2001 From: Shashi Gowda Date: Sat, 1 Jun 2024 23:46:52 -0400 Subject: [PATCH] some maketerm updates --- src/complex.jl | 4 ++-- src/diff.jl | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/complex.jl b/src/complex.jl index 7037051e5..849cbc94b 100644 --- a/src/complex.jl +++ b/src/complex.jl @@ -23,11 +23,11 @@ operation(a::ComplexTerm{T}) where T = Complex{T} arguments(a::ComplexTerm) = [a.re, a.im] metadata(a::ComplexTerm) = metadata(a.re) -function maketerm(t::ComplexTerm, f, args, symtype; metadata=nothing) +function maketerm(T::Type{<:ComplexTerm}, f, args, symtype, metadata) if f <: Complex ComplexTerm{real(f)}(args...) else - maketerm(first(args), f, args, symtype, metadata) + maketerm(typeof(first(args)), f, args, symtype, metadata) end end diff --git a/src/diff.jl b/src/diff.jl index 616f73ed9..016fe8c3a 100644 --- a/src/diff.jl +++ b/src/diff.jl @@ -541,7 +541,7 @@ function jacobian_sparsity(exprs::AbstractArray, vars::AbstractArray) J = Int[] - simterm(x, f, args; metadata = nothing, kw...) = maketerm(x, f, args, symtype(x), metadata; kw...) + mkterm(x, f, args, _, m) = maketerm(x, f, args, symtype(x), m) # This rewriter notes down which u's appear in a # given du (whose index is stored in the `i` Ref) @@ -552,7 +552,7 @@ function jacobian_sparsity(exprs::AbstractArray, vars::AbstractArray) nothing end - r = Rewriters.Postwalk(r, similarterm=simterm) + r = Rewriters.Postwalk(r, maketerm=mkterm) for ii = 1:length(du) i[] = ii @@ -633,7 +633,7 @@ end isidx(x) = x isa TermCombination -basic_simterm(t, g, args; kws...) = Term{Any}(g, args) +basic_mkterm(t, g, args, _, m) = metadata(Term{Any}(g, args), m) let # we do this in a let block so that Revise works on the list of rules @@ -661,7 +661,7 @@ let end end @rule ~x::issym => 0] - linearity_propagator = Fixpoint(Postwalk(Chain(linearity_rules); similarterm=basic_simterm)) + linearity_propagator = Fixpoint(Postwalk(Chain(linearity_rules); maketerm=basic_mkterm)) global hessian_sparsity