Skip to content

Commit

Permalink
fixes in polyform
Browse files Browse the repository at this point in the history
  • Loading branch information
shashi committed May 30, 2024
1 parent 8f22c01 commit 9d4562d
Showing 1 changed file with 14 additions and 11 deletions.
25 changes: 14 additions & 11 deletions src/polyform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,11 @@ function polyize(x, pvar2sym, sym2term, vtype, pow, Fs, recurse)
# create a new symbol to store this

y = if recurse
similarterm(x,
op,
map(a->PolyForm(a, pvar2sym, sym2term, vtype; Fs, recurse),
args), symtype(x))
maketerm(typeof(x),
op,
map(a->PolyForm(a, pvar2sym, sym2term, vtype; Fs, recurse), args),
symtype(x),
metadata(x))
else
x
end
Expand Down Expand Up @@ -175,7 +176,7 @@ isexpr(x::PolyForm) = true
iscall(x::Type{<:PolyForm}) = true
iscall(x::PolyForm) = true

function similarterm(::Type{<:PolyForm}, f, args, symtype, metadata)
function maketerm(::Type{<:PolyForm}, f, args, symtype, metadata)
basicsymbolic(t, f, args, symtype, metadata)
end
function maketerm(::Type{<:PolyForm}, f::Union{typeof(*), typeof(+), typeof(^)},
Expand Down Expand Up @@ -248,8 +249,10 @@ multivariate polynomials implementation.
expand(expr) = unpolyize(PolyForm(expr, Fs=Union{typeof(+), typeof(*), typeof(^)}, recurse=true))

function unpolyize(x)
simterm(x, f, args; kw...) = similarterm(x, f, args, symtype(x); kw...)
Postwalk(identity, similarterm=simterm)(x)
# we need a special makterm here because the default one used in Postwalk will call
# promote_symtype to get the new type, but we just want to forward that in case
# promote_symtype is not defined for some of the expressions here.
Postwalk(identity, maketerm=(T,f,args,sT,m) -> maketerm(T, f, args, symtype(x), m))(x)
end

function toterm(x::PolyForm)
Expand Down Expand Up @@ -301,7 +304,7 @@ function add_divs(x, y)
end
end

function frac_similarterm(x, f, args; kw...)
function frac_maketerm(T, f, args, stype, metadata)
if f in (*, /, \, +, -)
f(args...)
elseif f == (^)
Expand All @@ -311,7 +314,7 @@ function frac_similarterm(x, f, args; kw...)
args[1]^args[2]
end
else
similarterm(x, f, args; kw...)
maketerm(T, f, args, stype, metadata)
end
end

Expand All @@ -333,8 +336,8 @@ function simplify_fractions(x; polyform=false)
sdiv(a) = isdiv(a) ? simplify_div(a) : a

expr = Postwalk(sdiv quick_cancel,
similarterm=frac_similarterm)(Postwalk(add_with_div,
similarterm=frac_similarterm)(x))
maketerm=frac_maketerm)(Postwalk(add_with_div,
maketerm=frac_maketerm)(x))

polyform ? expr : unpolyize(expr)
end
Expand Down

0 comments on commit 9d4562d

Please sign in to comment.