From 914c3152ae17abcbf28c99e9cbf842ffe4ea16d8 Mon Sep 17 00:00:00 2001 From: Aaron Kaw Date: Wed, 4 Sep 2024 08:41:50 +1000 Subject: [PATCH 1/4] Passing a Function to derivative now feeds the Num input to the function --- src/diff.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diff.jl b/src/diff.jl index d9d4c1bbf..a368e9c11 100644 --- a/src/diff.jl +++ b/src/diff.jl @@ -371,6 +371,8 @@ derivative(::typeof(+), args::NTuple{N,Any}, ::Val) where {N} = 1 derivative(::typeof(*), args::NTuple{N,Any}, ::Val{i}) where {N,i} = *(deleteat!(collect(args), i)...) derivative(::typeof(one), args::Tuple{<:Any}, ::Val) = 0 +derivative(f::Function, x::Num) = derivative(f(x), x) + function count_order(x) @assert !(x isa Symbol) "The variable $x must have an order of differentiation that is greater or equal to 1!" n = 1 From b7d75d0e988ca7163eab341c2c0f856266e1190e Mon Sep 17 00:00:00 2001 From: Aaron Kaw Date: Wed, 4 Sep 2024 08:50:34 +1000 Subject: [PATCH 2/4] Added tests to check if function inputs to derivative behave as expected --- test/diff.jl | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/test/diff.jl b/test/diff.jl index 88956f138..23d9e837f 100644 --- a/test/diff.jl +++ b/test/diff.jl @@ -378,4 +378,15 @@ let Dt = Differential(t)^0 @test isequal(Dt, identity) test_equal(Dt(t + 2t^2), t + 2t^2) -end \ No newline at end of file +end + +# Check `Function` inputs for derivative (#1085) +let + @variables x + @testset for f in [sqrt, sin, acos, exp, cis] + @test isequal( + Symbolics.derivative(f, x), + Symbolics.derivative(f(x), x) + ) + end +end From 4d74c67546408fe93659faa9f9086be1447fa2c7 Mon Sep 17 00:00:00 2001 From: Aaron Kaw Date: Wed, 4 Sep 2024 09:10:15 +1000 Subject: [PATCH 3/4] derivative now throws for Function first input with non-Num second input --- src/diff.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diff.jl b/src/diff.jl index a368e9c11..891e2cde2 100644 --- a/src/diff.jl +++ b/src/diff.jl @@ -372,6 +372,7 @@ derivative(::typeof(*), args::NTuple{N,Any}, ::Val{i}) where {N,i} = *(deleteat! derivative(::typeof(one), args::Tuple{<:Any}, ::Val) = 0 derivative(f::Function, x::Num) = derivative(f(x), x) +derivative(::Function, x::Any) = TypeError(:derivative, "2nd argument", Num, typeof(x)) |> throw function count_order(x) @assert !(x isa Symbol) "The variable $x must have an order of differentiation that is greater or equal to 1!" From 6fcede0ce0c012c9481f2a1c500c6fc69a6ba9e5 Mon Sep 17 00:00:00 2001 From: Aaron Kaw Date: Wed, 4 Sep 2024 09:23:11 +1000 Subject: [PATCH 4/4] Test derivative throws for Function then non-Num input --- test/diff.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/test/diff.jl b/test/diff.jl index 23d9e837f..eb500cb9d 100644 --- a/test/diff.jl +++ b/test/diff.jl @@ -390,3 +390,11 @@ let ) end end + +# Check `Function` inputs throw for non-Num second input (#1085) +let + @testset for f in [sqrt, sin, acos, exp, cis] + @test_throws TypeError Symbolics.derivative(f, rand()) + @test_throws TypeError Symbolics.derivative(f, Val(rand(Int))) + end +end