From d0c0c14c54440cf361bb820dd921d1cc40b28558 Mon Sep 17 00:00:00 2001 From: Twan Koolen Date: Fri, 8 Dec 2017 14:09:37 -0500 Subject: [PATCH] Fix #24914. Reimplement and generalize all-scalar optimization. Add documentation. Explicitly return dest in various broadcast!-related methods. This is to make things easier on inference. Found by @timholy. --- base/broadcast.jl | 47 ++++++++++++++++++++++++++--------- base/sparse/higherorderfns.jl | 44 ++++++++++++++++++-------------- doc/src/manual/interfaces.md | 27 ++++++++++++++++++++ 3 files changed, 87 insertions(+), 31 deletions(-) diff --git a/base/broadcast.jl b/base/broadcast.jl index 603f8ec54efc29..7e8b2a0602e58e 100644 --- a/base/broadcast.jl +++ b/base/broadcast.jl @@ -239,12 +239,6 @@ broadcast_indices # special cases defined for performance broadcast(f, x::Number...) = f(x...) @inline broadcast(f, t::NTuple{N,Any}, ts::Vararg{NTuple{N,Any}}) where {N} = map(f, t, ts...) -@inline broadcast!(::typeof(identity), x::AbstractArray{T,N}, y::AbstractArray{S,N}) where {T,S,N} = - Base.axes(x) == Base.axes(y) ? copyto!(x, y) : _broadcast!(identity, x, y) - -# special cases for "X .= ..." (broadcast!) assignments -broadcast!(::typeof(identity), X::AbstractArray, x::Number) = fill!(X, x) -broadcast!(f, X::AbstractArray, x::Number...) = (@inbounds for I in eachindex(X); X[I] = f(x...); end; X) ## logic for deciding the BroadcastStyle # Dimensionality: computing max(M,N) in the type domain so we preserve inferrability @@ -261,7 +255,7 @@ longest(::Tuple{}, ::Tuple{}) = () # combine_styles operates on values (arbitrarily many) combine_styles(c) = result_style(BroadcastStyle(typeof(c))) combine_styles(c1, c2) = result_style(combine_styles(c1), combine_styles(c2)) -combine_styles(c1, c2, cs...) = result_style(combine_styles(c1), combine_styles(c2, cs...)) +@inline combine_styles(c1, c2, cs...) 
= result_style(combine_styles(c1), combine_styles(c2, cs...)) # result_style works on types (singletons and pairs), and leverages `BroadcastStyle` result_style(s::BroadcastStyle) = s @@ -397,6 +391,7 @@ Base.@propagate_inbounds _broadcast_getindex(::Style{Tuple}, A::Tuple{Any}, I) = result = @ncall $nargs f val @inbounds B[I] = result end + return B end end @@ -433,6 +428,7 @@ end @inbounds C[ind:bitcache_size] = false dumpbitcache(Bc, cind, C) end + return B end end @@ -445,11 +441,38 @@ Note that `dest` is only used to store the result, and does not supply arguments to `f` unless it is also listed in the `As`, as in `broadcast!(f, A, A, B)` to perform `A[:] = broadcast(f, A, B)`. """ -@inline broadcast!(f, C::AbstractArray, A, Bs::Vararg{Any,N}) where {N} = - _broadcast!(f, C, A, Bs...) +@inline broadcast!(f::Tf, dest, As::Vararg{Any,N}) where {Tf,N} = broadcast!(f, dest, combine_styles(As...), As...) +@inline broadcast!(f::Tf, dest, ::BroadcastStyle, As::Vararg{Any,N}) where {Tf,N} = broadcast!(f, dest, nothing, As...) + +# Default behavior (separated out so that it can be called by users who want to extend broadcast!). +@inline function broadcast!(f, dest, ::Void, As::Vararg{Any, N}) where N + if f isa typeof(identity) && N == 1 + A = As[1] + if A isa AbstractArray && Base.axes(dest) == Base.axes(A) + return copyto!(dest, A) + end + end + _broadcast!(f, dest, As...) + return dest +end + +# Optimization for the all-Scalar case. +@inline function broadcast!(f, dest, ::Scalar, As::Vararg{Any, N}) where N + if dest isa AbstractArray + if f isa typeof(identity) && N == 1 + return fill!(dest, As[1]) + else + @inbounds for I in eachindex(dest) + dest[I] = f(As...) + end + return dest + end + end + _broadcast!(f, dest, As...) + return dest +end -# This indirection allows size-dependent implementations (e.g., see the copying `identity` -# specialization above) +# This indirection allows size-dependent implementations. 
@inline function _broadcast!(f, C, A, Bs::Vararg{Any,N}) where N shape = broadcast_indices(C) @boundscheck check_broadcast_indices(shape, A, Bs...) @@ -630,7 +653,7 @@ function broadcast_nonleaf(f, s::NonleafHandlingTypes, ::Type{ElType}, shape::In dest = Base.similar(Array{typeof(val)}, shape) end dest[I] = val - return _broadcast!(f, dest, keeps, Idefaults, As, Val(nargs), iter, st, 1) + _broadcast!(f, dest, keeps, Idefaults, As, Val(nargs), iter, st, 1) end broadcast(f, ::Union{Scalar,Unknown}, ::Void, ::Void, a...) = f(a...) diff --git a/base/sparse/higherorderfns.jl b/base/sparse/higherorderfns.jl index 4f90b7bc6a4304..5ccd57e02a2f27 100644 --- a/base/sparse/higherorderfns.jl +++ b/base/sparse/higherorderfns.jl @@ -93,7 +93,8 @@ end # (3) broadcast[!] entry points broadcast(f::Tf, A::SparseVector) where {Tf} = _noshapecheck_map(f, A) broadcast(f::Tf, A::SparseMatrixCSC) where {Tf} = _noshapecheck_map(f, A) -function broadcast!(f::Tf, C::SparseVecOrMat) where Tf + +@inline function broadcast!(f::Tf, C::SparseVecOrMat, ::Void) where Tf isempty(C) && return _finishempty!(C) fofnoargs = f() if _iszero(fofnoargs) # f() is zero, so empty C @@ -106,14 +107,19 @@ function broadcast!(f::Tf, C::SparseVecOrMat) where Tf end return C end -function broadcast!(f::Tf, C::SparseVecOrMat, A::SparseVecOrMat, Bs::Vararg{SparseVecOrMat,N}) where {Tf,N} - _aresameshape(C, A, Bs...) && return _noshapecheck_map!(f, C, A, Bs...) - Base.Broadcast.check_broadcast_indices(axes(C), A, Bs...) - fofzeros = f(_zeros_eltypes(A, Bs...)...) - fpreszeros = _iszero(fofzeros) - return fpreszeros ? _broadcast_zeropres!(f, C, A, Bs...) : - _broadcast_notzeropres!(f, fofzeros, C, A, Bs...) 
+@inline function broadcast!(f::Tf, dest::SparseVecOrMat, ::Void, As::Vararg{Any,N}) where {Tf,N} + if f isa typeof(identity) && N == 1 + A = As[1] + if A isa Number + return fill!(dest, A) + elseif A isa AbstractArray && Base.axes(dest) == Base.axes(A) + return copyto!(dest, A) + end + end + spbroadcast_args!(f, dest, Broadcast.combine_styles(As...), As...) + return dest end + # the following three similar defs are necessary for type stability in the mixed vector/matrix case broadcast(f::Tf, A::SparseVector, Bs::Vararg{SparseVector,N}) where {Tf,N} = _aresameshape(A, Bs...) ? _noshapecheck_map(f, A, Bs...) : _diffshape_broadcast(f, A, Bs...) @@ -1006,26 +1012,26 @@ Broadcast.BroadcastStyle(::SparseMatStyle, ::Broadcast.DefaultArrayStyle{N}) whe broadcast(f, ::PromoteToSparse, ::Void, ::Void, As::Vararg{Any,N}) where {N} = broadcast(f, map(_sparsifystructured, As)...) -# ambiguity resolution -broadcast!(::typeof(identity), dest::SparseVecOrMat, x::Number) = - fill!(dest, x) -broadcast!(f, dest::SparseVecOrMat, x::Number...) = - spbroadcast_args!(f, dest, SPVM, x...) - # For broadcast! with ::Any inputs, we need a layer of indirection to determine whether # the inputs can be promoted to SparseVecOrMat. If it's just SparseVecOrMat and scalars, # we can handle it here, otherwise see below for the promotion machinery. -broadcast!(f, dest::SparseVecOrMat, mixedsrcargs::Vararg{Any,N}) where N = - spbroadcast_args!(f, dest, Broadcast.combine_styles(mixedsrcargs...), mixedsrcargs...) -function spbroadcast_args!(f, dest, ::Type{SPVM}, mixedsrcargs::Vararg{Any,N}) where N +function spbroadcast_args!(f::Tf, C, ::SPVM, A::SparseVecOrMat, Bs::Vararg{SparseVecOrMat,N}) where {Tf,N} + _aresameshape(C, A, Bs...) && return _noshapecheck_map!(f, C, A, Bs...) + Base.Broadcast.check_broadcast_indices(axes(C), A, Bs...) + fofzeros = f(_zeros_eltypes(A, Bs...)...) + fpreszeros = _iszero(fofzeros) + return fpreszeros ? _broadcast_zeropres!(f, C, A, Bs...) 
: + _broadcast_notzeropres!(f, fofzeros, C, A, Bs...) +end +function spbroadcast_args!(f::Tf, dest, ::SPVM, mixedsrcargs::Vararg{Any,N}) where {Tf,N} # mixedsrcargs contains nothing but SparseVecOrMat and scalars parevalf, passedsrcargstup = capturescalars(f, mixedsrcargs) return broadcast!(parevalf, dest, passedsrcargstup...) end -function spbroadcast_args!(f, dest, ::PromoteToSparse, mixedsrcargs::Vararg{Any,N}) where N +function spbroadcast_args!(f::Tf, dest, ::PromoteToSparse, mixedsrcargs::Vararg{Any,N}) where {Tf,N} broadcast!(f, dest, map(_sparsifystructured, mixedsrcargs)...) end -function spbroadcast_args!(f, dest, ::Any, mixedsrcargs::Vararg{Any,N}) where N +function spbroadcast_args!(f::Tf, dest, ::Any, mixedsrcargs::Vararg{Any,N}) where {Tf,N} # Fallback. From a performance perspective would it be best to densify? Broadcast._broadcast!(f, dest, mixedsrcargs...) end diff --git a/doc/src/manual/interfaces.md b/doc/src/manual/interfaces.md index e8498d31a8e97d..d206c7d4361be1 100644 --- a/doc/src/manual/interfaces.md +++ b/doc/src/manual/interfaces.md @@ -404,6 +404,8 @@ perhaps range-types `Ind` of your own design. For more information, see [Arrays | `broadcast(f, As...)` | Complete bypass of broadcasting machinery | | `broadcast(f, ::DestStyle, ::Void, ::Void, As...)` | Bypass after container type is computed | | `broadcast(f, ::DestStyle, ::Type{ElType}, inds::Tuple, As...)` | Bypass after container type, eltype, and indices are computed | +| `broadcast!(f, dest::DestType, ::Void, As...)` | Bypass in-place broadcast, specialization on destination type | +| `broadcast!(f, dest, ::BroadcastStyle, As...)` | Bypass in-place broadcast, specialization on `BroadcastStyle` | [Broadcasting](@ref) is triggered by an explicit call to `broadcast` or `broadcast!`, or implicitly by "dot" operations like `A .+ b`. 
Any `AbstractArray` type supports broadcasting, @@ -591,3 +593,28 @@ yields another `SparseVecStyle`, that its combination with a 2-dimensional array yields a `SparseMatStyle`, and anything of higher dimensionality falls back to the dense arbitrary-dimensional framework. These rules allow broadcasting to keep the sparse representation for operations that result in one or two dimensional outputs, but produce an `Array` for any other dimensionality. + +### [Extending `broadcast!`](@id extending-in-place-broadcast) + +Extending `broadcast!` (in-place broadcast) should be done with care, as it is easy to introduce +ambiguities between packages. To avoid these ambiguities, we adhere to the following conventions. + +First, if you want to specialize on the destination type, say `DestType`, then you should +define a method with the following signature: + +```julia +broadcast!(f, dest::DestType, ::Void, As...) +``` + +Note that no bounds should be placed on the types of `f` and `As...`. + +Second, if specialized `broadcast!` behavior is desired depending on the input types, +you should write [binary broadcasting rules](@ref writing-binary-broadcasting-rules) that +map the input types to a custom `BroadcastStyle`, say `MyBroadcastStyle`, and then define a method with the following +signature: + +```julia +broadcast!(f, dest, ::MyBroadcastStyle, As...) +``` + +Note that no bounds should be placed on the types of `f`, `dest`, and `As...`.