diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index f9eaf59f6..debedca49 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -11,7 +11,7 @@ export RuleConfig, HasReverseMode, NoReverseMode, HasForwardsMode, NoForwardsMod export frule_via_ad, rrule_via_ad # definition helper macros export @non_differentiable, @opt_out, @scalar_rule, @thunk, @not_implemented -export ProjectTo, canonicalize, unthunk # tangent operations +export ProjectTo, differential_type, canonicalize, unthunk # tangent operations export add!! # gradient accumulation operations export ignore_derivatives, @ignore_derivatives # tangents diff --git a/src/projection.jl b/src/projection.jl index 8eba26353..c8f3b93c3 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -40,6 +40,8 @@ Base.propertynames(p::ProjectTo) = propertynames(backing(p)) backing(project::ProjectTo) = getfield(project, :info) project_type(p::ProjectTo{T}) where {T} = T +project_type(::Type{<:ProjectTo{T}}) where {T} = T +project_type(_) = Any function Base.show(io::IO, project::ProjectTo{T}) where {T} print(io, "ProjectTo{") @@ -150,7 +152,7 @@ end ProjectTo(::Bool) = ProjectTo{NoTangent}() # same projector as ProjectTo(::AbstractZero) above # Other never-differentiable types -for T in (:Symbol, :Char, :AbstractString, :RoundingMode, :IndexStyle) +for T in (:Symbol, :Char, :AbstractString, :RoundingMode, :IndexStyle, :Nothing) @eval ProjectTo(::$T) = ProjectTo{NoTangent}() end @@ -601,3 +603,40 @@ function (project::ProjectTo{SparseMatrixCSC})(dx::SparseMatrixCSC) invoke(project, Tuple{AbstractArray}, dx) end end + +##### +##### A related utility which wants to live nearby +##### + +""" + differential_type(x) + differential_type(typeof(x)) + +Testing `differential_type(x) <: AbstractZero` will tell you whether `x` is +known to be non-differentiable. + +This relies on `ProjectTo(x)`, and the method accepting a type relies on type inference. +Thus it will not look inside abstractly typed containers such as `x = Any[true, false]`. + +```jldoctest +julia> differential_type(true) +NoTangent + +julia> differential_type(Int) +Float64 + +julia> x = Any[true, false]; + +julia> differential_type(x) +NoTangent + +julia> differential_type(typeof(x)) +Any +``` +""" +differential_type(x) = project_type(ProjectTo(x)) + +function differential_type(::Type{T}) where {T} + PT = Base._return_type(ProjectTo, Tuple{T}) # might be Union{} if unstable + return isconcretetype(PT) ? project_type(PT) : Any +end diff --git a/test/projection.jl b/test/projection.jl index 3e70772ac..1f50c6aa4 100644 --- a/test/projection.jl +++ b/test/projection.jl @@ -478,3 +478,13 @@ struct NoSuperType end @test_broken 0 == @ballocated $psymm(dx) setup = (dx = Symmetric(rand(10^3, 10^3))) # 64 end end + +@testset "differential_type" begin + @test differential_type(true) == differential_type(Bool) == NoTangent + @test differential_type(1) == differential_type(Int) == Float64 + tup = (false, :x, nothing) + @test differential_type(tup) == differential_type(typeof(tup)) == NoTangent + + @test differential_type(NoSuperType()) == differential_type(NoSuperType) == Any + @test differential_type(Dual(1,2)) == differential_type(Dual) == Real +end \ No newline at end of file