Skip to content

Commit

Permalink
Fix coeff(p, sym) when sym is a product, by iteratively computing…
Browse files Browse the repository at this point in the history
… the coefficient w.r.t. each term in the product.
  • Loading branch information
zengmao committed Aug 26, 2024
1 parent f6818bf commit bd14f03
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 0 deletions.
21 changes: 21 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -315,15 +315,36 @@ julia> Symbolics.coeff(3x + 2y, y)
julia> Symbolics.coeff(x^2 + y, x^2)
1
julia> Symbolics.coeff(2*x*y + y, x*y)
2
```
"""
function coeff(p, sym=nothing)
# if `sym` is a product, iteratively compute the coefficient w.r.t. each term in `sym`
if iscall(value(sym)) && operation(value(sym)) === (*)
for t in arguments(value(sym))
@assert !(t isa Number) "`coeff(p, sym)` does not allow `sym` containing numerical factors"
p = coeff(p, t)
end
return p
end

p, sym = value(p), value(sym)

if isequal(sym, 1)
sym = nothing
end

return _coeff(p, sym)
end

"""
_coeff(p, sym)
Function used internally by `coeff(p, sym)`, after the latter function performs some initial steps and re-assigns `p, sym = value(p), value(sym)`
"""
function _coeff(p, sym)
if issym(p) || isterm(p)
sym === nothing ? 0 : Int(isequal(p, sym))
elseif ispow(p)
Expand Down
6 changes: 6 additions & 0 deletions test/coeff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,9 @@ e = x*y^2 + 2x + y^3*x^3
@test isequal(coeff(x / 5, x), 1//5)
@test isequal(coeff(x / y, x), 1/y)
@test isequal(coeff(x * 5y / (1 + y + z) , x), 5y / (1 + y + z))

# issue #1041 - coefficient of cross term in multivariate polynomial

@test isequal(coeff(2*x*y + y, x*y), 2)
@test isequal(coeff(2*x^2*y + y, x^2*y), 2)
@test_throws AssertionError coeff(2*x*y + y, 2*x*y) # numerical factors not allowed in second argument of `coeff`

0 comments on commit bd14f03

Please sign in to comment.