diff --git a/src/utils.jl b/src/utils.jl index c6cafdb00..61458dc9e 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -315,15 +315,36 @@ julia> Symbolics.coeff(3x + 2y, y) julia> Symbolics.coeff(x^2 + y, x^2) 1 + +julia> Symbolics.coeff(2*x*y + y, x*y) +2 ``` """ function coeff(p, sym=nothing) + # if `sym` is a product, iteratively compute the coefficient w.r.t. each term in `sym` + if iscall(value(sym)) && operation(value(sym)) === (*) + for t in arguments(value(sym)) + @assert !(t isa Number) "`coeff(p, sym)` does not allow `sym` containing numerical factors" + p = coeff(p, t) + end + return p + end + p, sym = value(p), value(sym) if isequal(sym, 1) sym = nothing end + return _coeff(p, sym) +end + +""" + _coeff(p, sym) + +Function used internally by `coeff(p, sym)`, after the latter function performs some initial steps and re-assigns `p, sym = value(p), value(sym)` +""" +function _coeff(p, sym) if issym(p) || isterm(p) sym === nothing ? 0 : Int(isequal(p, sym)) elseif ispow(p) diff --git a/test/coeff.jl b/test/coeff.jl index 5ae4b71bd..ea5ad7d4a 100644 --- a/test/coeff.jl +++ b/test/coeff.jl @@ -54,3 +54,9 @@ e = x*y^2 + 2x + y^3*x^3 @test isequal(coeff(x / 5, x), 1//5) @test isequal(coeff(x / y, x), 1/y) @test isequal(coeff(x * 5y / (1 + y + z) , x), 5y / (1 + y + z)) + +# issue #1041 - coefficient of cross term in multivariate polynomial + +@test isequal(coeff(2*x*y + y, x*y), 2) +@test isequal(coeff(2*x^2*y + y, x^2*y), 2) +@test_throws AssertionError coeff(2*x*y + y, 2*x*y) # numerical factors not allowed in second argument of `coeff`