Skip to content

Commit

Permalink
some maketerm updates
Browse files Browse the repository at this point in the history
  • Loading branch information
shashi committed Jun 2, 2024
1 parent 7676c08 commit 2677315
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
4 changes: 2 additions & 2 deletions src/complex.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@ operation(a::ComplexTerm{T}) where T = Complex{T}
arguments(a::ComplexTerm) = [a.re, a.im]
metadata(a::ComplexTerm) = metadata(a.re)

function maketerm(t::ComplexTerm, f, args, symtype; metadata=nothing)
function maketerm(T::Type{<:ComplexTerm}, f, args, symtype, metadata)
if f <: Complex
ComplexTerm{real(f)}(args...)
else
maketerm(first(args), f, args, symtype, metadata)
maketerm(typeof(first(args)), f, args, symtype, metadata)
end
end

Expand Down
8 changes: 4 additions & 4 deletions src/diff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -541,7 +541,7 @@ function jacobian_sparsity(exprs::AbstractArray, vars::AbstractArray)
J = Int[]


simterm(x, f, args; metadata = nothing, kw...) = maketerm(x, f, args, symtype(x), metadata; kw...)
mkterm(x, f, args, _, m) = maketerm(x, f, args, symtype(x), m)

# This rewriter notes down which u's appear in a
# given du (whose index is stored in the `i` Ref)
Expand All @@ -552,7 +552,7 @@ function jacobian_sparsity(exprs::AbstractArray, vars::AbstractArray)
nothing
end

r = Rewriters.Postwalk(r, similarterm=simterm)
r = Rewriters.Postwalk(r, maketerm=mkterm)

for ii = 1:length(du)
i[] = ii
Expand Down Expand Up @@ -633,7 +633,7 @@ end

isidx(x) = x isa TermCombination

basic_simterm(t, g, args; kws...) = Term{Any}(g, args)
basic_mkterm(t, g, args, _, m) = metadata(Term{Any}(g, args), m)

let
# we do this in a let block so that Revise works on the list of rules
Expand Down Expand Up @@ -661,7 +661,7 @@ let
end
end
@rule ~x::issym => 0]
linearity_propagator = Fixpoint(Postwalk(Chain(linearity_rules); similarterm=basic_simterm))
linearity_propagator = Fixpoint(Postwalk(Chain(linearity_rules); maketerm=basic_mkterm))

global hessian_sparsity

Expand Down

0 comments on commit 2677315

Please sign in to comment.