Skip to content

Commit

Permalink
[ci skip]
Browse files Browse the repository at this point in the history
Fix #24914 (WIP).

Address comments.

Reimplement and generalize all-scalar optimization.

Fix allocation tests for sparse broadcast!.

Revert back to calling Broadcast.combine_styles twice.

Fix more allocation tests in higherorderfns.jl (by @timholy).
  • Loading branch information
tkoolen committed Dec 17, 2017
1 parent a5c9e88 commit 26769f7
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 30 deletions.
41 changes: 30 additions & 11 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,12 +239,6 @@ broadcast_indices
# special cases defined for performance
broadcast(f, x::Number...) = f(x...)
@inline broadcast(f, t::NTuple{N,Any}, ts::Vararg{NTuple{N,Any}}) where {N} = map(f, t, ts...)
@inline broadcast!(::typeof(identity), x::AbstractArray{T,N}, y::AbstractArray{S,N}) where {T,S,N} =
Base.axes(x) == Base.axes(y) ? copyto!(x, y) : _broadcast!(identity, x, y)

# special cases for "X .= ..." (broadcast!) assignments
broadcast!(::typeof(identity), X::AbstractArray, x::Number) = fill!(X, x)
broadcast!(f, X::AbstractArray, x::Number...) = (@inbounds for I in eachindex(X); X[I] = f(x...); end; X)

## logic for deciding the BroadcastStyle
# Dimensionality: computing max(M,N) in the type domain so we preserve inferrability
Expand All @@ -261,7 +255,7 @@ longest(::Tuple{}, ::Tuple{}) = ()
# combine_styles operates on values (arbitrarily many)
combine_styles(c) = result_style(BroadcastStyle(typeof(c)))
combine_styles(c1, c2) = result_style(combine_styles(c1), combine_styles(c2))
combine_styles(c1, c2, cs...) = result_style(combine_styles(c1), combine_styles(c2, cs...))
@inline combine_styles(c1, c2, cs...) = result_style(combine_styles(c1), combine_styles(c2, cs...))

# result_style works on types (singletons and pairs), and leverages `BroadcastStyle`
result_style(s::BroadcastStyle) = s
Expand Down Expand Up @@ -445,11 +439,36 @@ Note that `dest` is only used to store the result, and does not supply
arguments to `f` unless it is also listed in the `As`,
as in `broadcast!(f, A, A, B)` to perform `A[:] = broadcast(f, A, B)`.
"""
@inline broadcast!(f, C::AbstractArray, A, Bs::Vararg{Any,N}) where {N} =
_broadcast!(f, C, A, Bs...)
@inline broadcast!(f::Tf, dest, As::Vararg{Any,N}) where {Tf,N} = broadcast!(f, dest, combine_styles(As...), As...)
@inline broadcast!(f::Tf, dest, ::BroadcastStyle, As::Vararg{Any,N}) where {Tf,N} = broadcast!(f, dest, nothing, As...)

# Default behavior (separated out so that it can be called by users who want to extend broadcast!).
@inline function broadcast!(f, dest, ::Void, As::Vararg{Any, N}) where N
if f isa typeof(identity) && N == 1
A = As[1]
if A isa AbstractArray && Base.axes(dest) == Base.axes(A)
return copy!(dest, A)
end
end
return _broadcast!(f, dest, As...)
end

# Optimization for the all-Scalar case.
@inline function broadcast!(f, dest, ::Scalar, As::Vararg{Any, N}) where N
if dest isa AbstractArray
if f isa typeof(identity) && N == 1
return fill!(dest, As[1])
else
@inbounds for I in eachindex(dest)
dest[I] = f(As...)
end
return dest
end
end
return _broadcast!(f, dest, As...)
end

# This indirection allows size-dependent implementations (e.g., see the copying `identity`
# specialization above)
# This indirection allows size-dependent implementations.
@inline function _broadcast!(f, C, A, Bs::Vararg{Any,N}) where N
shape = broadcast_indices(C)
@boundscheck check_broadcast_indices(shape, A, Bs...)
Expand Down
43 changes: 24 additions & 19 deletions base/sparse/higherorderfns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ end
# (3) broadcast[!] entry points
broadcast(f::Tf, A::SparseVector) where {Tf} = _noshapecheck_map(f, A)
broadcast(f::Tf, A::SparseMatrixCSC) where {Tf} = _noshapecheck_map(f, A)
function broadcast!(f::Tf, C::SparseVecOrMat) where Tf

@inline function broadcast!(f::Tf, C::SparseVecOrMat, ::Void) where Tf
isempty(C) && return _finishempty!(C)
fofnoargs = f()
if _iszero(fofnoargs) # f() is zero, so empty C
Expand All @@ -106,14 +107,18 @@ function broadcast!(f::Tf, C::SparseVecOrMat) where Tf
end
return C
end
function broadcast!(f::Tf, C::SparseVecOrMat, A::SparseVecOrMat, Bs::Vararg{SparseVecOrMat,N}) where {Tf,N}
_aresameshape(C, A, Bs...) && return _noshapecheck_map!(f, C, A, Bs...)
Base.Broadcast.check_broadcast_indices(axes(C), A, Bs...)
fofzeros = f(_zeros_eltypes(A, Bs...)...)
fpreszeros = _iszero(fofzeros)
return fpreszeros ? _broadcast_zeropres!(f, C, A, Bs...) :
_broadcast_notzeropres!(f, fofzeros, C, A, Bs...)
@inline function broadcast!(f::Tf, dest::SparseVecOrMat, ::Void, As::Vararg{Any,N}) where {Tf,N}
if f isa typeof(identity) && N == 1
A = As[1]
if A isa Number
return fill!(dest, A)
elseif A isa AbstractArray && Base.axes(dest) == Base.axes(A)
return copy!(dest, A)
end
end
return spbroadcast_args!(f, dest, Broadcast.combine_styles(As...), As...)
end

# the following three similar defs are necessary for type stability in the mixed vector/matrix case
broadcast(f::Tf, A::SparseVector, Bs::Vararg{SparseVector,N}) where {Tf,N} =
_aresameshape(A, Bs...) ? _noshapecheck_map(f, A, Bs...) : _diffshape_broadcast(f, A, Bs...)
Expand Down Expand Up @@ -1006,26 +1011,26 @@ Broadcast.BroadcastStyle(::SparseMatStyle, ::Broadcast.DefaultArrayStyle{N}) whe
broadcast(f, ::PromoteToSparse, ::Void, ::Void, As::Vararg{Any,N}) where {N} =
broadcast(f, map(_sparsifystructured, As)...)

# ambiguity resolution
broadcast!(::typeof(identity), dest::SparseVecOrMat, x::Number) =
fill!(dest, x)
broadcast!(f, dest::SparseVecOrMat, x::Number...) =
spbroadcast_args!(f, dest, SPVM, x...)

# For broadcast! with ::Any inputs, we need a layer of indirection to determine whether
# the inputs can be promoted to SparseVecOrMat. If it's just SparseVecOrMat and scalars,
# we can handle it here, otherwise see below for the promotion machinery.
broadcast!(f, dest::SparseVecOrMat, mixedsrcargs::Vararg{Any,N}) where N =
spbroadcast_args!(f, dest, Broadcast.combine_styles(mixedsrcargs...), mixedsrcargs...)
function spbroadcast_args!(f, dest, ::Type{SPVM}, mixedsrcargs::Vararg{Any,N}) where N
function spbroadcast_args!(f::Tf, C, ::SPVM, A::SparseVecOrMat, Bs::Vararg{SparseVecOrMat,N}) where {Tf,N}
_aresameshape(C, A, Bs...) && return _noshapecheck_map!(f, C, A, Bs...)
Base.Broadcast.check_broadcast_indices(axes(C), A, Bs...)
fofzeros = f(_zeros_eltypes(A, Bs...)...)
fpreszeros = _iszero(fofzeros)
return fpreszeros ? _broadcast_zeropres!(f, C, A, Bs...) :
_broadcast_notzeropres!(f, fofzeros, C, A, Bs...)
end
function spbroadcast_args!(f::Tf, dest, ::SPVM, mixedsrcargs::Vararg{Any,N}) where {Tf,N}
# mixedsrcargs contains nothing but SparseVecOrMat and scalars
parevalf, passedsrcargstup = capturescalars(f, mixedsrcargs)
return broadcast!(parevalf, dest, passedsrcargstup...)
end
function spbroadcast_args!(f, dest, ::PromoteToSparse, mixedsrcargs::Vararg{Any,N}) where N
function spbroadcast_args!(f::Tf, dest, ::PromoteToSparse, mixedsrcargs::Vararg{Any,N}) where {Tf,N}
broadcast!(f, dest, map(_sparsifystructured, mixedsrcargs)...)
end
function spbroadcast_args!(f, dest, ::Any, mixedsrcargs::Vararg{Any,N}) where N
function spbroadcast_args!(f::Tf, dest, ::Any, mixedsrcargs::Vararg{Any,N}) where {Tf,N}
# Fallback. From a performance perspective would it be best to densify?
Broadcast._broadcast!(f, dest, mixedsrcargs...)
end
Expand Down

0 comments on commit 26769f7

Please sign in to comment.