From 418cc1865a1d7ebc67e953bbbb79ca8f059453d8 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 11 Jan 2022 10:19:23 -0500 Subject: [PATCH 1/2] add is_non_differentiable --- src/ChainRulesCore.jl | 2 +- src/projection.jl | 26 ++++++++++++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index f9eaf59f6..6962491fc 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -13,7 +13,7 @@ export frule_via_ad, rrule_via_ad export @non_differentiable, @opt_out, @scalar_rule, @thunk, @not_implemented export ProjectTo, canonicalize, unthunk # tangent operations export add!! # gradient accumulation operations -export ignore_derivatives, @ignore_derivatives +export ignore_derivatives, @ignore_derivatives, is_non_differentiable # tangents export Tangent, NoTangent, InplaceableThunk, Thunk, ZeroTangent, AbstractZero, AbstractThunk diff --git a/src/projection.jl b/src/projection.jl index 8eba26353..e97594fe5 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -142,16 +142,42 @@ end # dx::AbstractArray (when both are possible), or the reverse. So for now we just pass them through: (::ProjectTo{T})(dx::Tangent{<:T}) where {T} = dx +##### +##### A related utility which wants to live nearby +##### + +""" + is_non_differentiable(x) == is_non_differentiable(typeof(x)) + +Returns `true` if `x` is known from its type not to have derivatives, else `false`. + +Should mostly agree with whether `ProjectTo(x)` maps to `AbstractZero`, +which is what the fallback method checks. The exception is that it will not look +inside abstractly typed containers like `x = Any[true, false]`. +""" +is_non_differentiable(x) = is_non_differentiable(typeof(x)) + +is_non_differentiable(::Type{<:Number}) = false +is_non_differentiable(::Type{<:NTuple{N,T}}) where {N,T} = is_non_differentiable(T) +is_non_differentiable(::Type{<:AbstractArray{T}}) where {T} = is_non_differentiable(T) + +function is_non_differentiable(::Type{T}) where {T} # fallback + PT = Base._return_type(ProjectTo, Tuple{T}) # might be Union{} if unstable + return isconcretetype(PT) && PT <: ProjectTo{<:AbstractZero} +end + ##### ##### `Base` ##### # Bool ProjectTo(::Bool) = ProjectTo{NoTangent}() # same projector as ProjectTo(::AbstractZero) above +is_non_differentiable(::Type{Bool}) = true # Other never-differentiable types for T in (:Symbol, :Char, :AbstractString, :RoundingMode, :IndexStyle) @eval ProjectTo(::$T) = ProjectTo{NoTangent}() + @eval is_non_differentiable(::Type{<:$T}) = true end # Numbers From 7c8e040edcf44182d05cf177e7dccc9d986f5466 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 7 Jul 2022 14:42:07 -0600 Subject: [PATCH 2/2] change to return the type --- src/ChainRulesCore.jl | 4 +-- src/projection.jl | 67 ++++++++++++++++++++++++++----------------- test/projection.jl | 10 +++++++ 3 files changed, 52 insertions(+), 29 deletions(-) diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index 6962491fc..debedca49 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -11,9 +11,9 @@ export RuleConfig, HasReverseMode, NoReverseMode, HasForwardsMode, NoForwardsMod export frule_via_ad, rrule_via_ad # definition helper macros export @non_differentiable, @opt_out, @scalar_rule, @thunk, @not_implemented -export ProjectTo, canonicalize, unthunk # tangent operations +export ProjectTo, differential_type, canonicalize, unthunk # tangent operations export add!! # gradient accumulation operations -export ignore_derivatives, @ignore_derivatives, is_non_differentiable +export ignore_derivatives, @ignore_derivatives # tangents export Tangent, NoTangent, InplaceableThunk, Thunk, ZeroTangent, AbstractZero, AbstractThunk diff --git a/src/projection.jl b/src/projection.jl index e97594fe5..c8f3b93c3 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -40,6 +40,8 @@ Base.propertynames(p::ProjectTo) = propertynames(backing(p)) backing(project::ProjectTo) = getfield(project, :info) project_type(p::ProjectTo{T}) where {T} = T +project_type(::Type{<:ProjectTo{T}}) where {T} = T +project_type(_) = Any function Base.show(io::IO, project::ProjectTo{T}) where {T} print(io, "ProjectTo{") @@ -142,42 +144,16 @@ end # dx::AbstractArray (when both are possible), or the reverse. So for now we just pass them through: (::ProjectTo{T})(dx::Tangent{<:T}) where {T} = dx -##### -##### A related utility which wants to live nearby -##### - -""" - is_non_differentiable(x) == is_non_differentiable(typeof(x)) - -Returns `true` if `x` is known from its type not to have derivatives, else `false`. - -Should mostly agree with whether `ProjectTo(x)` maps to `AbstractZero`, -which is what the fallback method checks. The exception is that it will not look -inside abstractly typed containers like `x = Any[true, false]`. -""" -is_non_differentiable(x) = is_non_differentiable(typeof(x)) - -is_non_differentiable(::Type{<:Number}) = false -is_non_differentiable(::Type{<:NTuple{N,T}}) where {N,T} = is_non_differentiable(T) -is_non_differentiable(::Type{<:AbstractArray{T}}) where {T} = is_non_differentiable(T) - -function is_non_differentiable(::Type{T}) where {T} # fallback - PT = Base._return_type(ProjectTo, Tuple{T}) # might be Union{} if unstable - return isconcretetype(PT) && PT <: ProjectTo{<:AbstractZero} -end - ##### ##### `Base` ##### # Bool ProjectTo(::Bool) = ProjectTo{NoTangent}() # same projector as ProjectTo(::AbstractZero) above -is_non_differentiable(::Type{Bool}) = true # Other never-differentiable types -for T in (:Symbol, :Char, :AbstractString, :RoundingMode, :IndexStyle) +for T in (:Symbol, :Char, :AbstractString, :RoundingMode, :IndexStyle, :Nothing) @eval ProjectTo(::$T) = ProjectTo{NoTangent}() - @eval is_non_differentiable(::Type{<:$T}) = true end # Numbers @@ -627,3 +603,40 @@ function (project::ProjectTo{SparseMatrixCSC})(dx::SparseMatrixCSC) invoke(project, Tuple{AbstractArray}, dx) end end + +##### +##### A related utility which wants to live nearby +##### + +""" + differential_type(x) + differential_type(typeof(x)) + +Testing `differential_type(x) <: AbstractZero` will tell you whether `x` is +known to be non-differentiable. + +This relies on `ProjectTo(x)`, and the method accepting a type relies on type inference. +Thus it will not look inside abstractly typed containers such as `x = Any[true, false]`. + +```jldoctest +julia> differential_type(true) +NoTangent + +julia> differential_type(Int) +Float64 + +julia> x = Any[true, false]; + +julia> differential_type(x) +NoTangent + +julia> differential_type(typeof(x)) +Any +``` +""" +differential_type(x) = project_type(ProjectTo(x)) + +function differential_type(::Type{T}) where {T} + PT = Base._return_type(ProjectTo, Tuple{T}) # might be Union{} if unstable + return isconcretetype(PT) ? project_type(PT) : Any +end diff --git a/test/projection.jl b/test/projection.jl index 3e70772ac..1f50c6aa4 100644 --- a/test/projection.jl +++ b/test/projection.jl @@ -478,3 +478,13 @@ struct NoSuperType end @test_broken 0 == @ballocated $psymm(dx) setup = (dx = Symmetric(rand(10^3, 10^3))) # 64 end end + +@testset "differential_type" begin + @test differential_type(true) == differential_type(Bool) == NoTangent + @test differential_type(1) == differential_type(Int) == Float64 + tup = (false, :x, nothing) + @test differential_type(tup) == differential_type(typeof(tup)) == NoTangent + + @test differential_type(NoSuperType()) == differential_type(NoSuperType) == Any + @test differential_type(Dual(1,2)) == differential_type(Dual) == Real +end \ No newline at end of file