Skip to content

Commit

Permalink
Create ZeroOperator and make D(const) a hard zero
Browse files Browse the repository at this point in the history
This fixes an odd bug that can occur when a constant can match, like `Differential(2)(2x) == x`, when it should obviously be zero. This fixes Differential(number) == ZeroOperator which acts like a hard zero. You need ZeroOperator instead of 0 because otherwise you codegen to stuff like 0(2x) and get an error that 0 does not have a call, and `zero` the function does not do things like `zero * zero` appropriately, so constructing a simple operator that acts like a true zero is the simplest solution that enforces the safety.
  • Loading branch information
ChrisRackauckas committed Sep 19, 2024
1 parent ee77c8a commit 0a2643a
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 0 deletions.
13 changes: 13 additions & 0 deletions src/diff.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
abstract type Operator end
propagate_shape(::Operator, x) = axes(x)

struct ZeroOperator <: Operator
end
(::ZeroOperator)(x) = wrap(0)
Base.show(io::IO, D::Differential) = print(io, "ZeroOperator")
Base.nameof(D::ZeroOperator) = :ZeroOperator
Base.:*(D1, D2::ZeroOperator) = ZeroOperator()
Base.:*(D1::ZeroOperator, D2) = ZeroOperator()
Base.:*(D1::ZeroOperator, D2::ZeroOperator) = ZeroOperator()
Base.:^(D::ZeroOperator, n::Integer) = ZeroOperator()

"""
$(TYPEDEF)
Expand Down Expand Up @@ -33,6 +43,7 @@ struct Differential <: Operator
"""The variable or expression to differentiate with respect to."""
x
Differential(x) = new(value(x))
Differential(x::Union{AbstractFloat, Integer}) = ZeroOperator()
end
function (D::Differential)(x)
x = unwrap(x)
Expand All @@ -42,6 +53,8 @@ function (D::Differential)(x)
term(D, x)
end
end

(D::Differential)(x::Union{AbstractFloat, Integers}) = wrap(0)
(D::Differential)(x::Union{Num, Arr}) = wrap(D(unwrap(x)))
(D::Differential)(x::Complex{Num}) = wrap(ComplexTerm{Real}(D(unwrap(real(x))), D(unwrap(imag(x)))))
SymbolicUtils.promote_symtype(::Differential, T) = T
Expand Down
7 changes: 7 additions & 0 deletions test/diff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@ Dx = Differential(x)

test_equal(a, b) = @test isequal(simplify(a), simplify(b))

@testset "ZeroOperator handling" begin
@test isequal(Differential(0.1)(x), 0)
@test isequal(Differential(0.1)(0.1x), 0)
@test isequal(Differential(1)(x), 0)
@test isequal(Differential(2)(2x), 0)
end

#@test @macroexpand(@derivatives D'~t D2''~t) == @macroexpand(@derivatives (D'~t), (D2''~t))

@test isequal(expand_derivatives(D(t)), 1)
Expand Down

0 comments on commit 0a2643a

Please sign in to comment.