Skip to content

Commit

Permalink
Add promote_typejoin() and use it instead of typejoin() to preserve U…
Browse files Browse the repository at this point in the history
…nion{T, Nothing/Missing} (#25553)

* Introduce promote_typejoin() and use it instead of typejoin() where appropriate

typejoin() is now used only for inference, and promote_typejoin() everywhere else
to choose the appropriate element type for a collection.
  • Loading branch information
nalimilan authored and JeffBezanson committed Jan 27, 2018
1 parent eea727c commit f58d2cf
Show file tree
Hide file tree
Showing 17 changed files with 154 additions and 35 deletions.
4 changes: 2 additions & 2 deletions base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -578,7 +578,7 @@ function collect_to!(dest::AbstractArray{T}, itr, offs, st) where T
@inbounds dest[i] = el::T
i += 1
else
R = typejoin(T, S)
R = promote_typejoin(T, S)
new = similar(dest, R)
copyto!(new,1, dest,1, i-1)
@inbounds new[i] = el
Expand Down Expand Up @@ -608,7 +608,7 @@ function grow_to!(dest, itr, st)
if S === T || S <: T
push!(dest, el::T)
else
new = sizehint!(empty(dest, typejoin(T, S)), length(dest))
new = sizehint!(empty(dest, promote_typejoin(T, S)), length(dest))
if new isa AbstractSet
# TODO: merge back these two branches when copy! is re-enabled for sets/vectors
union!(new, dest)
Expand Down
4 changes: 2 additions & 2 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ module Broadcast
using Base.Cartesian
using Base: Indices, OneTo, linearindices, tail, to_shape,
_msk_end, unsafe_bitgetindex, bitcache_chunks, bitcache_size, dumpbitcache,
isoperator
isoperator, promote_typejoin
import Base: broadcast, broadcast!
export BroadcastStyle, broadcast_indices, broadcast_similar,
broadcast_getindex, broadcast_setindex!, dotview, @__dot__
Expand Down Expand Up @@ -507,7 +507,7 @@ end
else
# This element type doesn't fit in B. Allocate a new B with wider eltype,
# copy over old values, and continue
newB = Base.similar(B, typejoin(eltype(B), S))
newB = Base.similar(B, promote_typejoin(eltype(B), S))
for II in Iterators.take(iter, count)
newB[II] = B[II]
end
Expand Down
2 changes: 1 addition & 1 deletion base/dict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ function grow_to!(dest::AbstractDict{K,V}, itr, st) where V where K
if isa(k,K) && isa(v,V)
dest[k] = v
else
new = empty(dest, typejoin(K,typeof(k)), typejoin(V,typeof(v)))
new = empty(dest, promote_typejoin(K,typeof(k)), promote_typejoin(V,typeof(v)))
merge!(new, dest)
new[k] = v
return grow_to!(new, itr, st)
Expand Down
18 changes: 12 additions & 6 deletions base/missing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,18 @@ nonmissingtype(::Type{Missing}) = Union{}
nonmissingtype(::Type{T}) where {T} = T
nonmissingtype(::Type{Any}) = Any

promote_rule(::Type{Missing}, ::Type{T}) where {T} = Union{T, Missing}
promote_rule(::Type{Union{S,Missing}}, ::Type{T}) where {T,S} = Union{promote_type(T, S), Missing}
promote_rule(::Type{Any}, ::Type{T}) where {T} = Any
promote_rule(::Type{Any}, ::Type{Missing}) = Any
promote_rule(::Type{Missing}, ::Type{Any}) = Any
promote_rule(::Type{Missing}, ::Type{Missing}) = Missing
for U in (:Nothing, :Missing)
@eval begin
promote_rule(::Type{$U}, ::Type{T}) where {T} = Union{T, $U}
promote_rule(::Type{Union{S,$U}}, ::Type{T}) where {T,S} = Union{promote_type(T, S), $U}
promote_rule(::Type{Any}, ::Type{$U}) = Any
promote_rule(::Type{$U}, ::Type{Any}) = Any
promote_rule(::Type{$U}, ::Type{$U}) = U
end
end
promote_rule(::Type{Union{Nothing, Missing}}, ::Type{Any}) = Any
promote_rule(::Type{Union{Nothing, Missing}}, ::Type{T}) where {T} =
Union{Nothing, Missing, T}

convert(::Type{Union{T, Missing}}, x) where {T} = convert(T, x)
# To fix ambiguities
Expand Down
3 changes: 3 additions & 0 deletions base/namedtuple.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ indexed_next(t::NamedTuple, i::Int, state) = (getfield(t, i), i+1)
isempty(::NamedTuple{()}) = true
isempty(::NamedTuple) = false

promote_typejoin(::Type{NamedTuple{n, S}}, ::Type{NamedTuple{n, T}}) where {n, S, T} =
NamedTuple{n, promote_typejoin(S, T)}

convert(::Type{NamedTuple{names,T}}, nt::NamedTuple{names,T}) where {names,T} = nt
convert(::Type{NamedTuple{names}}, nt::NamedTuple{names}) where {names} = nt

Expand Down
69 changes: 50 additions & 19 deletions base/promotion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,61 +5,66 @@
"""
typejoin(T, S)
Compute a type that contains both `T` and `S`.
Return the closest common ancestor of `T` and `S`, i.e. the narrowest type from which
they both inherit.
"""
typejoin() = (@_pure_meta; Bottom)
typejoin(@nospecialize(t)) = (@_pure_meta; t)
typejoin(@nospecialize(t), ts...) = (@_pure_meta; typejoin(t, typejoin(ts...)))
function typejoin(@nospecialize(a), @nospecialize(b))
@_pure_meta
typejoin(@nospecialize(a), @nospecialize(b)) = (@_pure_meta; join_types(a, b, typejoin))

function join_types(@nospecialize(a), @nospecialize(b), f::Function)
if a <: b
return b
elseif b <: a
return a
elseif isa(a,UnionAll)
return UnionAll(a.var, typejoin(a.body, b))
return UnionAll(a.var, join_types(a.body, b, f))
elseif isa(b,UnionAll)
return UnionAll(b.var, typejoin(a, b.body))
return UnionAll(b.var, join_types(a, b.body, f))
elseif isa(a,TypeVar)
return typejoin(a.ub, b)
return f(a.ub, b)
elseif isa(b,TypeVar)
return typejoin(a, b.ub)
return f(a, b.ub)
elseif isa(a,Union)
return typejoin(typejoin(a.a,a.b), b)
a′ = f(a.a,a.b)
return a′ === a ? typejoin(a, b) : f(a′, b)
elseif isa(b,Union)
return typejoin(a, typejoin(b.a,b.b))
b′ = f(b.a,b.b)
return b′ === b ? typejoin(a, b) : f(a, b′)
elseif a <: Tuple
if !(b <: Tuple)
return Any
end
ap, bp = a.parameters, b.parameters
lar = length(ap)::Int; lbr = length(bp)::Int
if lar == 0
return Tuple{Vararg{tailjoin(bp,1)}}
return Tuple{Vararg{tailjoin(bp,1,f)}}
end
if lbr == 0
return Tuple{Vararg{tailjoin(ap,1)}}
return Tuple{Vararg{tailjoin(ap,1,f)}}
end
laf, afixed = full_va_len(ap)
lbf, bfixed = full_va_len(bp)
if laf < lbf
if isvarargtype(ap[lar]) && !afixed
c = Vector{Any}(uninitialized, laf)
c[laf] = Vararg{typejoin(unwrapva(ap[lar]), tailjoin(bp,laf))}
c[laf] = Vararg{f(unwrapva(ap[lar]), tailjoin(bp,laf,f))}
n = laf-1
else
c = Vector{Any}(uninitialized, laf+1)
c[laf+1] = Vararg{tailjoin(bp,laf+1)}
c[laf+1] = Vararg{tailjoin(bp,laf+1,f)}
n = laf
end
elseif lbf < laf
if isvarargtype(bp[lbr]) && !bfixed
c = Vector{Any}(uninitialized, lbf)
c[lbf] = Vararg{typejoin(unwrapva(bp[lbr]), tailjoin(ap,lbf))}
c[lbf] = Vararg{f(unwrapva(bp[lbr]), tailjoin(ap,lbf,f))}
n = lbf-1
else
c = Vector{Any}(uninitialized, lbf+1)
c[lbf+1] = Vararg{tailjoin(ap,lbf+1)}
c[lbf+1] = Vararg{tailjoin(ap,lbf+1,f)}
n = lbf
end
else
Expand All @@ -68,7 +73,7 @@ function typejoin(@nospecialize(a), @nospecialize(b))
end
for i = 1:n
ai = ap[min(i,lar)]; bi = bp[min(i,lbr)]
ci = typejoin(unwrapva(ai),unwrapva(bi))
ci = f(unwrapva(ai),unwrapva(bi))
c[i] = i == length(c) && (isvarargtype(ai) || isvarargtype(bi)) ? Vararg{ci} : ci
end
return Tuple{c...}
Expand Down Expand Up @@ -102,6 +107,28 @@ function typejoin(@nospecialize(a), @nospecialize(b))
return Any
end

"""
promote_typejoin(T, S)
Compute a type that contains both `T` and `S`, which could be
either a parent of both types, or a `Union` if appropriate.
Falls back to [`typejoin`](@ref).
"""
promote_typejoin(@nospecialize(a), @nospecialize(b)) =
(@_pure_meta; join_types(a, b, promote_typejoin))
promote_typejoin(::Type{Nothing}, ::Type{T}) where {T} =
isconcretetype(T) ? Union{T, Nothing} : Any
promote_typejoin(::Type{T}, ::Type{Nothing}) where {T} =
isconcretetype(T) ? Union{T, Nothing} : Any
promote_typejoin(::Type{Missing}, ::Type{T}) where {T} =
isconcretetype(T) ? Union{T, Missing} : Any
promote_typejoin(::Type{T}, ::Type{Missing}) where {T} =
isconcretetype(T) ? Union{T, Missing} : Any
promote_typejoin(::Type{Nothing}, ::Type{Missing}) = Union{Nothing, Missing}
promote_typejoin(::Type{Missing}, ::Type{Nothing}) = Union{Nothing, Missing}
promote_typejoin(::Type{Nothing}, ::Type{Nothing}) = Nothing
promote_typejoin(::Type{Missing}, ::Type{Missing}) = Missing

# Returns length, isfixed
function full_va_len(p)
isempty(p) && return 0, true
Expand All @@ -116,14 +143,14 @@ function full_va_len(p)
return length(p)::Int, true
end

# reduce typejoin over A[i:end]
function tailjoin(A, i)
# reduce join_types over A[i:end]
function tailjoin(A, i, f::Function)
if i > length(A)
return unwrapva(A[end])
end
t = Bottom
for j = i:length(A)
t = typejoin(t, unwrapva(A[j]))
t = f(t, unwrapva(A[j]))
end
return t
end
Expand Down Expand Up @@ -193,6 +220,10 @@ it for new types as appropriate.
function promote_rule end

promote_rule(::Type{<:Any}, ::Type{<:Any}) = Bottom
# To fix ambiguities
promote_rule(::Type{Any}, ::Type{<:Any}) = Any
promote_rule(::Type{<:Any}, ::Type{Any}) = Any
promote_rule(::Type{Any}, ::Type{Any}) = Any

promote_result(::Type{<:Any},::Type{<:Any},::Type{T},::Type{S}) where {T,S} = (@_inline_meta; promote_type(T,S))
# If no promote_rule is defined, both directions give Bottom. In that
Expand Down
2 changes: 1 addition & 1 deletion base/range.jl
Original file line number Diff line number Diff line change
Expand Up @@ -800,7 +800,7 @@ broadcast(::typeof(\), x::Number, r::LinSpace) = LinSpace(x \ r.start, x \
el_same(::Type{T}, a::Type{<:AbstractArray{T,n}}, b::Type{<:AbstractArray{T,n}}) where {T,n} = a
el_same(::Type{T}, a::Type{<:AbstractArray{T,n}}, b::Type{<:AbstractArray{S,n}}) where {T,S,n} = a
el_same(::Type{T}, a::Type{<:AbstractArray{S,n}}, b::Type{<:AbstractArray{T,n}}) where {T,S,n} = b
el_same(::Type, a, b) = typejoin(a, b)
el_same(::Type, a, b) = promote_typejoin(a, b)

promote_rule(a::Type{UnitRange{T1}}, b::Type{UnitRange{T2}}) where {T1,T2} =
el_same(promote_type(T1,T2), a, b)
Expand Down
2 changes: 1 addition & 1 deletion base/set.jl
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ _unique_from(itr, out, seen, i) = unique_from(itr, out, seen, i)
x, i = next(itr, i)
S = typeof(x)
if !(S === T || S <: T)
R = typejoin(S, T)
R = promote_typejoin(S, T)
seenR = convert(Set{R}, seen)
outR = convert(Vector{R}, out)
if !in(x, seenR)
Expand Down
4 changes: 2 additions & 2 deletions base/tuple.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,11 @@ end
eltype(t::Type{<:Tuple}) = _compute_eltype(t)
function _compute_eltype(t::Type{<:Tuple})
@_pure_meta
t isa Union && return typejoin(eltype(t.a), eltype(t.b))
t isa Union && return promote_typejoin(eltype(t.a), eltype(t.b))
= unwrap_unionall(t)
r = Union{}
for ti in.parameters
r = typejoin(r, rewrap_unionall(unwrapva(ti), t))
r = promote_typejoin(r, rewrap_unionall(unwrapva(ti), t))
end
return r
end
Expand Down
2 changes: 1 addition & 1 deletion base/twiceprecision.jl
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ eltype(::Type{TwicePrecision{T}}) where {T} = T

promote_rule(::Type{TwicePrecision{R}}, ::Type{TwicePrecision{S}}) where {R,S} =
TwicePrecision{promote_type(R,S)}
promote_rule(::Type{TwicePrecision{R}}, ::Type{S}) where {R,S} =
promote_rule(::Type{TwicePrecision{R}}, ::Type{S}) where {R,S<:Number} =
TwicePrecision{promote_type(R,S)}

(::Type{T})(x::TwicePrecision) where {T<:Number} = T(x.hi + x.lo)::T
Expand Down
1 change: 1 addition & 0 deletions test/ambiguous.jl
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,7 @@ end
pop!(need_to_handle_undef_sparam, which(Base.cat, (Any, SparseArrays._TypedDenseConcatGroup{T} where T)))
pop!(need_to_handle_undef_sparam, which(Base.float, Tuple{AbstractArray{Union{Missing, T},N} where {T, N}}))
pop!(need_to_handle_undef_sparam, which(Base.convert, Tuple{Type{Union{Missing, T}} where T, Any}))
pop!(need_to_handle_undef_sparam, which(Base.promote_rule, Tuple{Type{Union{Nothing, S}} where S, Type{T} where T}))
pop!(need_to_handle_undef_sparam, which(Base.promote_rule, Tuple{Type{Union{Missing, S}} where S, Type{T} where T}))
pop!(need_to_handle_undef_sparam, which(Base.zero, Tuple{Type{Union{Missing, T}} where T}))
pop!(need_to_handle_undef_sparam, which(Base.one, Tuple{Type{Union{Missing, T}} where T}))
Expand Down
3 changes: 3 additions & 0 deletions test/arrayops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1222,6 +1222,9 @@ end
@test isequal([1,2,3], [a for (a,b) in enumerate(2:4)])
@test isequal([2,3,4], [b for (a,b) in enumerate(2:4)])

@test [s for s in Union{String, Nothing}["a", nothing]] isa Vector{Union{String, Nothing}}
@test [s for s in Union{String, Missing}["a", missing]] isa Vector{Union{String, Missing}}

@testset "comprehension in let-bound function" begin
let xy = sum([x[i]*y[i] for i=1:length(x)])
@test [1,2] [3,4] == 11
Expand Down
13 changes: 13 additions & 0 deletions test/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,19 @@ end
@test typejoin(Tuple{Vararg{Int,2}}, Tuple{Int,Int,Int}) === Tuple{Int,Int,Vararg{Int}}
@test typejoin(Tuple{Vararg{Int,2}}, Tuple{Vararg{Int}}) === Tuple{Vararg{Int}}

# promote_typejoin returns a Union only with Nothing/Missing combined with concrete types
for T in (Nothing, Missing)
@test Base.promote_typejoin(Int, Float64) === Real
@test Base.promote_typejoin(Int, T) === Union{Int, T}
@test Base.promote_typejoin(T, String) === Union{T, String}
@test Base.promote_typejoin(Vector{Int}, T) === Union{Vector{Int}, T}
@test Base.promote_typejoin(Vector, T) === Any
@test Base.promote_typejoin(Real, T) === Any
@test Base.promote_typejoin(Int, String) === Any
@test Base.promote_typejoin(Int, Union{Float64, T}) === Any
@test Base.promote_typejoin(Int, Union{String, T}) === Any
end

@test promote_type(Bool,Bottom) === Bool

# type declarations
Expand Down
27 changes: 27 additions & 0 deletions test/functional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,3 +150,30 @@ end
for n = 0:5:100-q-d
for p = 100-q-d-n
if p < n < d < q] == [(50,30,15,5), (50,30,20,0), (50,40,10,0), (75,20,5,0)]

@testset "map/collect return type on generators with $T" for T in (Nothing, Missing)
x = ["a", "b"]
res = @inferred collect(s for s in x)
@test res isa Vector{String}
res = @inferred map(identity, x)
@test res isa Vector{String}
res = @inferred collect(s isa T for s in x)
@test res isa Vector{Bool}
res = @inferred map(s -> s isa T, x)
@test res isa Vector{Bool}
y = Union{String, T}["a", T()]
f(s::Union{Nothing, Missing}) = s
f(s::String) = s == "a"
res = collect(s for s in y)
@test res isa Vector{Union{String, T}}
res = map(identity, y)
@test res isa Vector{Union{String, T}}
res = @inferred collect(s isa T for s in y)
@test res isa Vector{Bool}
res = @inferred map(s -> s isa T, y)
@test res isa Vector{Bool}
res = collect(f(s) for s in y)
@test res isa Vector{Union{Bool, T}}
res = map(f, y)
@test res isa Vector{Union{Bool, T}}
end
10 changes: 10 additions & 0 deletions test/missing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,16 @@ end
@test_broken promote_type(Union{Nothing, Missing, Int}, Float64) == Any
end

@testset "promotion in various contexts" for T in (Nothing, Missing)
@test collect(v for v in (1, T())) isa Vector{Union{Int,T}}
@test map(identity, Any[1, T()]) isa Vector{Union{Int,T}}
@test broadcast(identity, Any[1, T()]) isa Vector{Union{Int,T}}
@test unique((1, T())) isa Vector{Union{Int,T}}

@test map(ismissing, Any[1, missing]) isa Vector{Bool}
@test broadcast(ismissing, Any[1, missing]) isa BitVector
end

@testset "comparison operators" begin
@test (missing == missing) === missing
@test (1 == missing) === missing
Expand Down
11 changes: 11 additions & 0 deletions test/namedtuple.jl
Original file line number Diff line number Diff line change
Expand Up @@ -221,3 +221,14 @@ abstr_nt_22194_3()
@test findlast(equalto(1), ()) === nothing
@test findfirst(equalto(1), (a=2, b=3)) === nothing
@test findlast(equalto(1), (a=2, b=3)) === nothing

# Test map with Nothing and Missing
for T in (Nothing, Missing)
x = [(a=1, b=T()), (a=1, b=2)]
y = map(v -> (a=v.a, b=v.b), [(a=1, b=T()), (a=1, b=2)])
@test y isa Vector{NamedTuple{(:a,:b),Tuple{Int,Union{T,Int}}}}
@test isequal(x, y)
end
y = map(v -> (a=v.a, b=v.a + v.b), [(a=1, b=missing), (a=1, b=2)])
@test y isa Vector{NamedTuple{(:a,:b),Tuple{Int,Union{Missing,Int}}}}
@test isequal(y, [(a=1, b=missing), (a=1, b=3)])
14 changes: 14 additions & 0 deletions test/tuple.jl
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,20 @@ end
typejoin(Int, AbstractFloat, Bool)
@test eltype(Union{Tuple{Int, Float64}, Tuple{Vararg{Bool}}}) ===
typejoin(Int, Float64, Bool)
@test eltype(Tuple{Int, Missing}) === Union{Missing, Int}
@test eltype(Tuple{Int, Nothing}) === Union{Nothing, Int}
end

@testset "map with Nothing and Missing" begin
for T in (Nothing, Missing)
x = [(1, T()), (1, 2)]
y = map(v -> (v[1], v[2]), [(1, T()), (1, 2)])
@test y isa Vector{Tuple{Int,Union{T,Int}}}
@test isequal(x, y)
end
y = map(v -> (v[1], v[1] + v[2]), [(1, missing), (1, 2)])
@test y isa Vector{Tuple{Int,Union{Missing,Int}}}
@test isequal(y, [(1, missing), (1, 3)])
end

@testset "mapping" begin
Expand Down

0 comments on commit f58d2cf

Please sign in to comment.