Skip to content

Commit

Permalink
change to return the type
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Jul 7, 2022
1 parent 418cc18 commit 7c8e040
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 29 deletions.
4 changes: 2 additions & 2 deletions src/ChainRulesCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ export RuleConfig, HasReverseMode, NoReverseMode, HasForwardsMode, NoForwardsMod
export frule_via_ad, rrule_via_ad
# definition helper macros
export @non_differentiable, @opt_out, @scalar_rule, @thunk, @not_implemented
export ProjectTo, canonicalize, unthunk # tangent operations
export ProjectTo, differential_type, canonicalize, unthunk # tangent operations
export add!! # gradient accumulation operations
export ignore_derivatives, @ignore_derivatives, is_non_differentiable
export ignore_derivatives, @ignore_derivatives
# tangents
export Tangent, NoTangent, InplaceableThunk, Thunk, ZeroTangent, AbstractZero, AbstractThunk

Expand Down
67 changes: 40 additions & 27 deletions src/projection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ Base.propertynames(p::ProjectTo) = propertynames(backing(p))
backing(project::ProjectTo) = getfield(project, :info)

project_type(p::ProjectTo{T}) where {T} = T
project_type(::Type{<:ProjectTo{T}}) where {T} = T
project_type(_) = Any

function Base.show(io::IO, project::ProjectTo{T}) where {T}
print(io, "ProjectTo{")
Expand Down Expand Up @@ -142,42 +144,16 @@ end
# dx::AbstractArray (when both are possible), or the reverse. So for now we just pass them through:
(::ProjectTo{T})(dx::Tangent{<:T}) where {T} = dx

#####
##### A related utility which wants to live nearby
#####

"""
is_non_differentiable(x) == is_non_differentiable(typeof(x))
Returns `true` if `x` is known from its type not to have derivatives, else `false`.
Should mostly agree with whether `ProjectTo(x)` maps to `AbstractZero`,
which is what the fallback method checks. The exception is that it will not look
inside abstractly typed containers like `x = Any[true, false]`.
"""
is_non_differentiable(x) = is_non_differentiable(typeof(x))

is_non_differentiable(::Type{<:Number}) = false
is_non_differentiable(::Type{<:NTuple{N,T}}) where {N,T} = is_non_differentiable(T)
is_non_differentiable(::Type{<:AbstractArray{T}}) where {T} = is_non_differentiable(T)

function is_non_differentiable(::Type{T}) where {T} # fallback
PT = Base._return_type(ProjectTo, Tuple{T}) # might be Union{} if unstable
return isconcretetype(PT) && PT <: ProjectTo{<:AbstractZero}
end

#####
##### `Base`
#####

# Bool
ProjectTo(::Bool) = ProjectTo{NoTangent}() # same projector as ProjectTo(::AbstractZero) above
is_non_differentiable(::Type{Bool}) = true

# Other never-differentiable types
for T in (:Symbol, :Char, :AbstractString, :RoundingMode, :IndexStyle)
for T in (:Symbol, :Char, :AbstractString, :RoundingMode, :IndexStyle, :Nothing)
@eval ProjectTo(::$T) = ProjectTo{NoTangent}()
@eval is_non_differentiable(::Type{<:$T}) = true
end

# Numbers
Expand Down Expand Up @@ -627,3 +603,40 @@ function (project::ProjectTo{SparseMatrixCSC})(dx::SparseMatrixCSC)
invoke(project, Tuple{AbstractArray}, dx)
end
end

#####
##### A related utility which wants to live nearby
#####

"""
differential_type(x)
differential_type(typeof(x))
Testing `differential_type(x) <: AbstractZero` will tell you whether `x` is
known to be non-differentiable.
This relies on `ProjectTo(x)`, and the method accepting a type relies on type inference.
Thus it will not look inside abstractly typed containers such as `x = Any[true, false]`.
```jldoctest
julia> differential_type(true)
NoTangent
julia> differential_type(Int)
Float64
julia> x = Any[true, false];
julia> differential_type(x)
NoTangent
julia> differential_type(typeof(x))
Any
```
"""
differential_type(x) = project_type(ProjectTo(x))

function differential_type(::Type{T}) where {T}
PT = Base._return_type(ProjectTo, Tuple{T}) # might be Union{} if unstable
return isconcretetype(PT) ? project_type(PT) : Any
end
10 changes: 10 additions & 0 deletions test/projection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -478,3 +478,13 @@ struct NoSuperType end
@test_broken 0 == @ballocated $psymm(dx) setup = (dx = Symmetric(rand(10^3, 10^3))) # 64
end
end

@testset "differential_type" begin
@test differential_type(true) == differential_type(Bool) == NoTangent
@test differential_type(1) == differential_type(Int) == Float64
tup = (false, :x, nothing)
@test differential_type(tup) == differential_type(typeof(tup)) == NoTangent

@test differential_type(NoSuperType()) == differential_type(NoSuperType) == Any
@test differential_type(Dual(1,2)) == differential_type(Dual) == Real
end

0 comments on commit 7c8e040

Please sign in to comment.