Skip to content

Commit

Permalink
Fix more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed May 30, 2024
1 parent e1a6422 commit 059d56c
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 7 deletions.
4 changes: 2 additions & 2 deletions src/arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -611,7 +611,7 @@ function replace_by_scalarizing(ex, dict)

simterm = (x, f, args; kws...) -> begin
if metadata(x) !== nothing
maketerm(x, f, args; metadata=metadata(x))
maketerm(typeof(x), f, args, symtype(x), metadata(x))
else
f(args...)
end
Expand All @@ -622,7 +622,7 @@ function replace_by_scalarizing(ex, dict)
f = operation(x)
ff = replace_by_scalarizing(f, dict)
if metadata(x) !== nothing
maketerm(x, ff, arguments(x); metadata=metadata(x))
maketerm(typeof(x), ff, arguments(x), symtype(x), metadata(x))
else
ff(arguments(x)...)
end
Expand Down
6 changes: 3 additions & 3 deletions src/diff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -552,7 +552,7 @@ function jacobian_sparsity(exprs::AbstractArray, vars::AbstractArray)
nothing
end

r = Rewriters.Postwalk(r, maketerm=simterm)
r = Rewriters.Postwalk(r, similarterm=simterm)

for ii = 1:length(du)
i[] = ii
Expand Down Expand Up @@ -661,7 +661,7 @@ let
end
end
@rule ~x::issym => 0]
linearity_propagator = Fixpoint(Postwalk(Chain(linearity_rules); maketerm=basic_simterm))
linearity_propagator = Fixpoint(Postwalk(Chain(linearity_rules); similarterm=basic_simterm))

global hessian_sparsity

Expand Down Expand Up @@ -695,7 +695,7 @@ let
u = map(value, vars)
idx(i) = TermCombination(Set([Dict(i=>1)]))
dict = Dict(u .=> idx.(1:length(u)))
f = Rewriters.Prewalk(x->haskey(dict, x) ? dict[x] : x; maketerm=basic_simterm)(expr)
f = Rewriters.Prewalk(x->haskey(dict, x) ? dict[x] : x; similarterm=basic_simterm)(expr)
lp = linearity_propagator(f)
S = _sparse(lp, length(u))
S = full ? S : tril(S)
Expand Down
4 changes: 2 additions & 2 deletions test/arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,9 @@ end
@variables A[1:5, 1:5] B[1:5, 1:5]

T = unwrap(3A)
@test isequal(T, maketerm(T, operation(T), arguments(T)))
@test isequal(T, Symbolics.maketerm(T, operation(T), arguments(T)))
T2 = unwrap(3B)
@test isequal(T2, maketerm(T, operation(T), [*, 3, unwrap(B)]))
@test isequal(T2, Symbolics.maketerm(T, operation(T), [*, 3, unwrap(B)]))
end

getdef(v) = getmetadata(v, Symbolics.VariableDefaultValue)
Expand Down

0 comments on commit 059d56c

Please sign in to comment.