Skip to content

Commit

Permalink
Fix #24914.
Browse files Browse the repository at this point in the history
Reimplement and generalize all-scalar optimization.

Add documentation.

Explicitly return dest in various broadcast!-related methods. This is to make things easier on inference. Found by @timholy.

Collapse spbroadcast_args! into broadcast! as suggested by @Sacha0.
  • Loading branch information
tkoolen committed Dec 23, 2017
1 parent 6c94bd8 commit fa3fe32
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 36 deletions.
47 changes: 35 additions & 12 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,12 +239,6 @@ broadcast_indices
# special cases defined for performance
broadcast(f, x::Number...) = f(x...)
@inline broadcast(f, t::NTuple{N,Any}, ts::Vararg{NTuple{N,Any}}) where {N} = map(f, t, ts...)
@inline broadcast!(::typeof(identity), x::AbstractArray{T,N}, y::AbstractArray{S,N}) where {T,S,N} =
Base.axes(x) == Base.axes(y) ? copyto!(x, y) : _broadcast!(identity, x, y)

# special cases for "X .= ..." (broadcast!) assignments
broadcast!(::typeof(identity), X::AbstractArray, x::Number) = fill!(X, x)
broadcast!(f, X::AbstractArray, x::Number...) = (@inbounds for I in eachindex(X); X[I] = f(x...); end; X)

## logic for deciding the BroadcastStyle
# Dimensionality: computing max(M,N) in the type domain so we preserve inferrability
Expand All @@ -261,7 +255,7 @@ longest(::Tuple{}, ::Tuple{}) = ()
# combine_styles operates on values (arbitrarily many)
combine_styles(c) = result_style(BroadcastStyle(typeof(c)))
combine_styles(c1, c2) = result_style(combine_styles(c1), combine_styles(c2))
combine_styles(c1, c2, cs...) = result_style(combine_styles(c1), combine_styles(c2, cs...))
@inline combine_styles(c1, c2, cs...) = result_style(combine_styles(c1), combine_styles(c2, cs...))

# result_style works on types (singletons and pairs), and leverages `BroadcastStyle`
result_style(s::BroadcastStyle) = s
Expand Down Expand Up @@ -397,6 +391,7 @@ Base.@propagate_inbounds _broadcast_getindex(::Style{Tuple}, A::Tuple{Any}, I) =
result = @ncall $nargs f val
@inbounds B[I] = result
end
return B
end
end

Expand Down Expand Up @@ -433,6 +428,7 @@ end
@inbounds C[ind:bitcache_size] = false
dumpbitcache(Bc, cind, C)
end
return B
end
end

Expand All @@ -445,11 +441,38 @@ Note that `dest` is only used to store the result, and does not supply
arguments to `f` unless it is also listed in the `As`,
as in `broadcast!(f, A, A, B)` to perform `A[:] = broadcast(f, A, B)`.
"""
@inline broadcast!(f, C::AbstractArray, A, Bs::Vararg{Any,N}) where {N} =
_broadcast!(f, C, A, Bs...)
@inline broadcast!(f::Tf, dest, As::Vararg{Any,N}) where {Tf,N} = broadcast!(f, dest, combine_styles(As...), As...)
@inline broadcast!(f::Tf, dest, ::BroadcastStyle, As::Vararg{Any,N}) where {Tf,N} = broadcast!(f, dest, nothing, As...)

# Default behavior (separated out so that it can be called by users who want to extend broadcast!).
@inline function broadcast!(f, dest, ::Nothing, As::Vararg{Any, N}) where N
if f isa typeof(identity) && N == 1
A = As[1]
if A isa AbstractArray && Base.axes(dest) == Base.axes(A)
return copyto!(dest, A)
end
end
_broadcast!(f, dest, As...)
return dest
end

# Optimization for the all-Scalar case.
@inline function broadcast!(f, dest, ::Scalar, As::Vararg{Any, N}) where N
if dest isa AbstractArray
if f isa typeof(identity) && N == 1
return fill!(dest, As[1])
else
@inbounds for I in eachindex(dest)
dest[I] = f(As...)
end
return dest
end
end
_broadcast!(f, dest, As...)
return dest
end

# This indirection allows size-dependent implementations (e.g., see the copying `identity`
# specialization above)
# This indirection allows size-dependent implementations.
@inline function _broadcast!(f, C, A, Bs::Vararg{Any,N}) where N
shape = broadcast_indices(C)
@boundscheck check_broadcast_indices(shape, A, Bs...)
Expand Down Expand Up @@ -630,7 +653,7 @@ function broadcast_nonleaf(f, s::NonleafHandlingTypes, ::Type{ElType}, shape::In
dest = Base.similar(Array{typeof(val)}, shape)
end
dest[I] = val
return _broadcast!(f, dest, keeps, Idefaults, As, Val(nargs), iter, st, 1)
_broadcast!(f, dest, keeps, Idefaults, As, Val(nargs), iter, st, 1)
end

broadcast(f, ::Union{Scalar,Unknown}, ::Nothing, ::Nothing, a...) = f(a...)
Expand Down
44 changes: 20 additions & 24 deletions base/sparse/higherorderfns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ end
# (3) broadcast[!] entry points
broadcast(f::Tf, A::SparseVector) where {Tf} = _noshapecheck_map(f, A)
broadcast(f::Tf, A::SparseMatrixCSC) where {Tf} = _noshapecheck_map(f, A)
function broadcast!(f::Tf, C::SparseVecOrMat) where Tf

@inline function broadcast!(f::Tf, C::SparseVecOrMat, ::Nothing) where Tf
isempty(C) && return _finishempty!(C)
fofnoargs = f()
if _iszero(fofnoargs) # f() is zero, so empty C
Expand All @@ -106,14 +107,7 @@ function broadcast!(f::Tf, C::SparseVecOrMat) where Tf
end
return C
end
function broadcast!(f::Tf, C::SparseVecOrMat, A::SparseVecOrMat, Bs::Vararg{SparseVecOrMat,N}) where {Tf,N}
_aresameshape(C, A, Bs...) && return _noshapecheck_map!(f, C, A, Bs...)
Base.Broadcast.check_broadcast_indices(axes(C), A, Bs...)
fofzeros = f(_zeros_eltypes(A, Bs...)...)
fpreszeros = _iszero(fofzeros)
return fpreszeros ? _broadcast_zeropres!(f, C, A, Bs...) :
_broadcast_notzeropres!(f, fofzeros, C, A, Bs...)
end

# the following three similar defs are necessary for type stability in the mixed vector/matrix case
broadcast(f::Tf, A::SparseVector, Bs::Vararg{SparseVector,N}) where {Tf,N} =
_aresameshape(A, Bs...) ? _noshapecheck_map(f, A, Bs...) : _diffshape_broadcast(f, A, Bs...)
Expand Down Expand Up @@ -1006,28 +1000,30 @@ Broadcast.BroadcastStyle(::SparseMatStyle, ::Broadcast.DefaultArrayStyle{N}) whe
broadcast(f, ::PromoteToSparse, ::Nothing, ::Nothing, As::Vararg{Any,N}) where {N} =
broadcast(f, map(_sparsifystructured, As)...)

# ambiguity resolution
broadcast!(::typeof(identity), dest::SparseVecOrMat, x::Number) =
fill!(dest, x)
broadcast!(f, dest::SparseVecOrMat, x::Number...) =
spbroadcast_args!(f, dest, SPVM, x...)

# For broadcast! with ::Any inputs, we need a layer of indirection to determine whether
# the inputs can be promoted to SparseVecOrMat. If it's just SparseVecOrMat and scalars,
# we can handle it here, otherwise see below for the promotion machinery.
broadcast!(f, dest::SparseVecOrMat, mixedsrcargs::Vararg{Any,N}) where N =
spbroadcast_args!(f, dest, Broadcast.combine_styles(mixedsrcargs...), mixedsrcargs...)
function spbroadcast_args!(f, dest, ::Type{SPVM}, mixedsrcargs::Vararg{Any,N}) where N
function broadcast!(f::Tf, dest::SparseVecOrMat, ::SPVM, A::SparseVecOrMat, Bs::Vararg{SparseVecOrMat,N}) where {Tf,N}
if f isa typeof(identity) && N == 0 && Base.axes(dest) == Base.axes(A)
return copyto!(dest, A)
end
_aresameshape(dest, A, Bs...) && return _noshapecheck_map!(f, dest, A, Bs...)
Base.Broadcast.check_broadcast_indices(axes(dest), A, Bs...)
fofzeros = f(_zeros_eltypes(A, Bs...)...)
fpreszeros = _iszero(fofzeros)
fpreszeros ? _broadcast_zeropres!(f, dest, A, Bs...) :
_broadcast_notzeropres!(f, fofzeros, dest, A, Bs...)
return dest
end
function broadcast!(f::Tf, dest::SparseVecOrMat, ::SPVM, mixedsrcargs::Vararg{Any,N}) where {Tf,N}
# mixedsrcargs contains nothing but SparseVecOrMat and scalars
parevalf, passedsrcargstup = capturescalars(f, mixedsrcargs)
return broadcast!(parevalf, dest, passedsrcargstup...)
broadcast!(parevalf, dest, passedsrcargstup...)
return dest
end
function spbroadcast_args!(f, dest, ::PromoteToSparse, mixedsrcargs::Vararg{Any,N}) where N
function broadcast!(f::Tf, dest::SparseVecOrMat, ::PromoteToSparse, mixedsrcargs::Vararg{Any,N}) where {Tf,N}
broadcast!(f, dest, map(_sparsifystructured, mixedsrcargs)...)
end
function spbroadcast_args!(f, dest, ::Any, mixedsrcargs::Vararg{Any,N}) where N
# Fallback. From a performance perspective would it be best to densify?
Broadcast._broadcast!(f, dest, mixedsrcargs...)
return dest
end

_sparsifystructured(M::AbstractMatrix) = SparseMatrixCSC(M)
Expand Down
36 changes: 36 additions & 0 deletions doc/src/manual/interfaces.md
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,8 @@ perhaps range-types `Ind` of your own design. For more information, see [Arrays
| `broadcast(f, As...)` | Complete bypass of broadcasting machinery |
| `broadcast(f, ::DestStyle, ::Nothing, ::Nothing, As...)` | Bypass after container type is computed |
| `broadcast(f, ::DestStyle, ::Type{ElType}, inds::Tuple, As...)` | Bypass after container type, eltype, and indices are computed |
| `broadcast!(f, dest::DestType, ::Nothing, As...)` | Bypass in-place broadcast, specialization on destination type |
| `broadcast!(f, dest, ::BroadcastStyle, As...)` | Bypass in-place broadcast, specialization on `BroadcastStyle` |

[Broadcasting](@ref) is triggered by an explicit call to `broadcast` or `broadcast!`, or implicitly by
"dot" operations like `A .+ b`. Any `AbstractArray` type supports broadcasting,
Expand Down Expand Up @@ -591,3 +593,37 @@ yields another `SparseVecStyle`, that its combination with a 2-dimensional array
yields a `SparseMatStyle`, and anything of higher dimensionality falls back to the dense arbitrary-dimensional framework.
These rules allow broadcasting to keep the sparse representation for operations that result
in one or two dimensional outputs, but produce an `Array` for any other dimensionality.

### [Extending `broadcast!`](@id extending-in-place-broadcast)

Extending `broadcast!` (in-place broadcast) should be done with care, as it is easy to introduce
ambiguities between packages. To avoid these ambiguities, we adhere to the following conventions.

First, if you want to specialize on the destination type, say `DestType`, then you should
define a method with the following signature:

```julia
broadcast!(f, dest::DestType, ::Nothing, As...)
```

Note that no bounds should be placed on the types of `f` and `As...`.

Second, if specialized `broadcast!` behavior is desired depending on the input types,
you should write [binary broadcasting rules](@ref writing-binary-broadcasting-rules) to
determine a custom `BroadcastStyle` given the input types, say `MyBroadcastStyle`, and you should define a method with the following
signature:

```julia
broadcast!(f, dest, ::MyBroadcastStyle, As...)
```

Note the lack of bounds on `f`, `dest`, and `As...`.

Third, simultaneously specializing on both the type of `dest` and the `BroadcastStyle` is fine. In this case,
it is also allowed to specialize on the types of the source arguments (`As...`). For example, these method signatures are OK:

```julia
broadcast!(f, dest::DestType, ::MyBroadcastStyle, As...)
broadcast!(f, dest::DestType, ::MyBroadcastStyle, As::AbstractArray...)
broadcast!(f, dest::DestType, ::Broadcast.Scalar, As::Number...)
```

0 comments on commit fa3fe32

Please sign in to comment.