diff --git a/src/extra_functions.jl b/src/extra_functions.jl index 4f934ae2f..040f14f72 100644 --- a/src/extra_functions.jl +++ b/src/extra_functions.jl @@ -1,4 +1,37 @@ -@register_symbolic Base.binomial(n, k::Integer)::Int true +@register_symbolic Base.binomial(n, k)::Int true +function _binomial(nothing, n, k) + begin + args = [n, k] + unwrapped_args = map(Symbolics.unwrap, args) + res = if !(any((x->begin + SymbolicUtils.issym(x) || SymbolicUtils.istree(x) + end), unwrapped_args)) + Base.binomial(unwrapped_args...) + else + SymbolicUtils.Term{Int}(Base.binomial, unwrapped_args) + end + if typeof.(args) == typeof.(unwrapped_args) + return res + else + return Symbolics.wrap(res) + end + end +end + +for (T1, T2) in ((Symbolics.SymbolicUtils.Symbolic{<:Real}, Int64), + (Num, Int64), + (Real, Symbolics.SymbolicUtils.Symbolic{<:Int64}), + (Symbolics.SymbolicUtils.Symbolic{<:Real}, Symbolics.SymbolicUtils.Symbolic{<:Int64}), + (Num, Symbolics.SymbolicUtils.Symbolic{<:Int64})) + + @eval function Base.binomial(n::$T1, k::$T2) + if any(Symbolics.iswrapped, (n, k)) + Symbolics.wrap(_binomial(nothing, Symbolics.unwrap(n), Symbolics.unwrap(k))) + else + _binomial(nothing, n, k) + end + end +end @register_symbolic Base.sign(x)::Int derivative(::typeof(sign), args::NTuple{1,Any}, ::Val{1}) = 0 diff --git a/test/overloads.jl b/test/overloads.jl index a32b39c54..84fcc3ecd 100644 --- a/test/overloads.jl +++ b/test/overloads.jl @@ -237,3 +237,18 @@ for f in [<, <=, >, >=, isless] end @test_nowarn binomial(t, 1) + +# test for https://github.com/JuliaSymbolics/Symbolics.jl/issues/1028 +let + @variables t A(t) B + @test try binomial(A, 2*B^2) + true + catch + false + end + @test try binomial(Symbolics.value(A), Symbolics.value(2*B^2)) + true + catch + false + end +end