Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Make it easier to extend broadcast! #24992

Merged
merged 1 commit into from
Dec 24, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 35 additions & 12 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,12 +239,6 @@ broadcast_indices
# special cases defined for performance
broadcast(f, x::Number...) = f(x...)
@inline broadcast(f, t::NTuple{N,Any}, ts::Vararg{NTuple{N,Any}}) where {N} = map(f, t, ts...)
@inline broadcast!(::typeof(identity), x::AbstractArray{T,N}, y::AbstractArray{S,N}) where {T,S,N} =
Base.axes(x) == Base.axes(y) ? copyto!(x, y) : _broadcast!(identity, x, y)

# special cases for "X .= ..." (broadcast!) assignments
broadcast!(::typeof(identity), X::AbstractArray, x::Number) = fill!(X, x)
broadcast!(f, X::AbstractArray, x::Number...) = (@inbounds for I in eachindex(X); X[I] = f(x...); end; X)

## logic for deciding the BroadcastStyle
# Dimensionality: computing max(M,N) in the type domain so we preserve inferrability
Expand All @@ -261,7 +255,7 @@ longest(::Tuple{}, ::Tuple{}) = ()
# combine_styles operates on values (arbitrarily many)
combine_styles(c) = result_style(BroadcastStyle(typeof(c)))
combine_styles(c1, c2) = result_style(combine_styles(c1), combine_styles(c2))
combine_styles(c1, c2, cs...) = result_style(combine_styles(c1), combine_styles(c2, cs...))
@inline combine_styles(c1, c2, cs...) = result_style(combine_styles(c1), combine_styles(c2, cs...))

# result_style works on types (singletons and pairs), and leverages `BroadcastStyle`
result_style(s::BroadcastStyle) = s
Expand Down Expand Up @@ -397,6 +391,7 @@ Base.@propagate_inbounds _broadcast_getindex(::Style{Tuple}, A::Tuple{Any}, I) =
result = @ncall $nargs f val
@inbounds B[I] = result
end
return B
end
end

Expand Down Expand Up @@ -433,6 +428,7 @@ end
@inbounds C[ind:bitcache_size] = false
dumpbitcache(Bc, cind, C)
end
return B
end
end

Expand All @@ -445,11 +441,38 @@ Note that `dest` is only used to store the result, and does not supply
arguments to `f` unless it is also listed in the `As`,
as in `broadcast!(f, A, A, B)` to perform `A[:] = broadcast(f, A, B)`.
"""
@inline broadcast!(f, C::AbstractArray, A, Bs::Vararg{Any,N}) where {N} =
_broadcast!(f, C, A, Bs...)
@inline broadcast!(f::Tf, dest, As::Vararg{Any,N}) where {Tf,N} = broadcast!(f, dest, combine_styles(As...), As...)
@inline broadcast!(f::Tf, dest, ::BroadcastStyle, As::Vararg{Any,N}) where {Tf,N} = broadcast!(f, dest, nothing, As...)

# Default behavior (separated out so that it can be called by users who want to extend broadcast!).
@inline function broadcast!(f, dest, ::Nothing, As::Vararg{Any, N}) where N
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@StefanKarpinski, thanks again for seeing the Void -> Nothing change through! So much more intuitive :).

if f isa typeof(identity) && N == 1
A = As[1]
if A isa AbstractArray && Base.axes(dest) == Base.axes(A)
return copyto!(dest, A)
end
end
_broadcast!(f, dest, As...)
return dest
end

# Optimization for the all-Scalar case.
@inline function broadcast!(f, dest, ::Scalar, As::Vararg{Any, N}) where N
if dest isa AbstractArray
if f isa typeof(identity) && N == 1
return fill!(dest, As[1])
else
@inbounds for I in eachindex(dest)
dest[I] = f(As...)
end
return dest
end
end
_broadcast!(f, dest, As...)
return dest
end

# This indirection allows size-dependent implementations (e.g., see the copying `identity`
# specialization above)
# This indirection allows size-dependent implementations.
@inline function _broadcast!(f, C, A, Bs::Vararg{Any,N}) where N
shape = broadcast_indices(C)
@boundscheck check_broadcast_indices(shape, A, Bs...)
Expand Down Expand Up @@ -630,7 +653,7 @@ function broadcast_nonleaf(f, s::NonleafHandlingTypes, ::Type{ElType}, shape::In
dest = Base.similar(Array{typeof(val)}, shape)
end
dest[I] = val
return _broadcast!(f, dest, keeps, Idefaults, As, Val(nargs), iter, st, 1)
_broadcast!(f, dest, keeps, Idefaults, As, Val(nargs), iter, st, 1)
end

broadcast(f, ::Union{Scalar,Unknown}, ::Nothing, ::Nothing, a...) = f(a...)
Expand Down
44 changes: 20 additions & 24 deletions base/sparse/higherorderfns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ end
# (3) broadcast[!] entry points
broadcast(f::Tf, A::SparseVector) where {Tf} = _noshapecheck_map(f, A)
broadcast(f::Tf, A::SparseMatrixCSC) where {Tf} = _noshapecheck_map(f, A)
function broadcast!(f::Tf, C::SparseVecOrMat) where Tf

@inline function broadcast!(f::Tf, C::SparseVecOrMat, ::Nothing) where Tf
isempty(C) && return _finishempty!(C)
fofnoargs = f()
if _iszero(fofnoargs) # f() is zero, so empty C
Expand All @@ -106,14 +107,7 @@ function broadcast!(f::Tf, C::SparseVecOrMat) where Tf
end
return C
end
function broadcast!(f::Tf, C::SparseVecOrMat, A::SparseVecOrMat, Bs::Vararg{SparseVecOrMat,N}) where {Tf,N}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Turned this into a spbroadcast_args! method.

_aresameshape(C, A, Bs...) && return _noshapecheck_map!(f, C, A, Bs...)
Base.Broadcast.check_broadcast_indices(axes(C), A, Bs...)
fofzeros = f(_zeros_eltypes(A, Bs...)...)
fpreszeros = _iszero(fofzeros)
return fpreszeros ? _broadcast_zeropres!(f, C, A, Bs...) :
_broadcast_notzeropres!(f, fofzeros, C, A, Bs...)
end

# the following three similar defs are necessary for type stability in the mixed vector/matrix case
broadcast(f::Tf, A::SparseVector, Bs::Vararg{SparseVector,N}) where {Tf,N} =
_aresameshape(A, Bs...) ? _noshapecheck_map(f, A, Bs...) : _diffshape_broadcast(f, A, Bs...)
Expand Down Expand Up @@ -1006,28 +1000,30 @@ Broadcast.BroadcastStyle(::SparseMatStyle, ::Broadcast.DefaultArrayStyle{N}) whe
broadcast(f, ::PromoteToSparse, ::Nothing, ::Nothing, As::Vararg{Any,N}) where {N} =
broadcast(f, map(_sparsifystructured, As)...)

# ambiguity resolution
broadcast!(::typeof(identity), dest::SparseVecOrMat, x::Number) =
fill!(dest, x)
broadcast!(f, dest::SparseVecOrMat, x::Number...) =
spbroadcast_args!(f, dest, SPVM, x...)

# For broadcast! with ::Any inputs, we need a layer of indirection to determine whether
# the inputs can be promoted to SparseVecOrMat. If it's just SparseVecOrMat and scalars,
# we can handle it here, otherwise see below for the promotion machinery.
broadcast!(f, dest::SparseVecOrMat, mixedsrcargs::Vararg{Any,N}) where N =
spbroadcast_args!(f, dest, Broadcast.combine_styles(mixedsrcargs...), mixedsrcargs...)
function spbroadcast_args!(f, dest, ::Type{SPVM}, mixedsrcargs::Vararg{Any,N}) where N
function broadcast!(f::Tf, dest::SparseVecOrMat, ::SPVM, A::SparseVecOrMat, Bs::Vararg{SparseVecOrMat,N}) where {Tf,N}
if f isa typeof(identity) && N == 0 && Base.axes(dest) == Base.axes(A)
return copyto!(dest, A)
end
_aresameshape(dest, A, Bs...) && return _noshapecheck_map!(f, dest, A, Bs...)
Base.Broadcast.check_broadcast_indices(axes(dest), A, Bs...)
fofzeros = f(_zeros_eltypes(A, Bs...)...)
fpreszeros = _iszero(fofzeros)
fpreszeros ? _broadcast_zeropres!(f, dest, A, Bs...) :
_broadcast_notzeropres!(f, fofzeros, dest, A, Bs...)
return dest
end
function broadcast!(f::Tf, dest::SparseVecOrMat, ::SPVM, mixedsrcargs::Vararg{Any,N}) where {Tf,N}
# mixedsrcargs contains nothing but SparseVecOrMat and scalars
parevalf, passedsrcargstup = capturescalars(f, mixedsrcargs)
return broadcast!(parevalf, dest, passedsrcargstup...)
broadcast!(parevalf, dest, passedsrcargstup...)
return dest
end
function spbroadcast_args!(f, dest, ::PromoteToSparse, mixedsrcargs::Vararg{Any,N}) where N
function broadcast!(f::Tf, dest::SparseVecOrMat, ::PromoteToSparse, mixedsrcargs::Vararg{Any,N}) where {Tf,N}
broadcast!(f, dest, map(_sparsifystructured, mixedsrcargs)...)
end
function spbroadcast_args!(f, dest, ::Any, mixedsrcargs::Vararg{Any,N}) where N
# Fallback. From a performance perspective would it be best to densify?
Broadcast._broadcast!(f, dest, mixedsrcargs...)
return dest
end

_sparsifystructured(M::AbstractMatrix) = SparseMatrixCSC(M)
Expand Down
36 changes: 36 additions & 0 deletions doc/src/manual/interfaces.md
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,8 @@ perhaps range-types `Ind` of your own design. For more information, see [Arrays
| `broadcast(f, As...)` | Complete bypass of broadcasting machinery |
| `broadcast(f, ::DestStyle, ::Nothing, ::Nothing, As...)` | Bypass after container type is computed |
| `broadcast(f, ::DestStyle, ::Type{ElType}, inds::Tuple, As...)` | Bypass after container type, eltype, and indices are computed |
| `broadcast!(f, dest::DestType, ::Nothing, As...)` | Bypass in-place broadcast, specialization on destination type |
| `broadcast!(f, dest, ::BroadcastStyle, As...)` | Bypass in-place broadcast, specialization on `BroadcastStyle` |

[Broadcasting](@ref) is triggered by an explicit call to `broadcast` or `broadcast!`, or implicitly by
"dot" operations like `A .+ b`. Any `AbstractArray` type supports broadcasting,
Expand Down Expand Up @@ -591,3 +593,37 @@ yields another `SparseVecStyle`, that its combination with a 2-dimensional array
yields a `SparseMatStyle`, and anything of higher dimensionality falls back to the dense arbitrary-dimensional framework.
These rules allow broadcasting to keep the sparse representation for operations that result
in one or two dimensional outputs, but produce an `Array` for any other dimensionality.

### [Extending `broadcast!`](@id extending-in-place-broadcast)

Extending `broadcast!` (in-place broadcast) should be done with care, as it is easy to introduce
ambiguities between packages. To avoid these ambiguities, we adhere to the following conventions.

First, if you want to specialize on the destination type, say `DestType`, then you should
define a method with the following signature:

```julia
broadcast!(f, dest::DestType, ::Nothing, As...)
```

Note that no bounds should be placed on the types of `f` and `As...`.

Second, if specialized `broadcast!` behavior is desired depending on the input types,
you should write [binary broadcasting rules](@ref writing-binary-broadcasting-rules) to
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where is the writing-binary-broadcasting-rules section?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

determine a custom `BroadcastStyle` given the input types, say `MyBroadcastStyle`, and you should define a method with the following
signature:

```julia
broadcast!(f, dest, ::MyBroadcastStyle, As...)
```

Note the lack of bounds on `f`, `dest`, and `As...`.

Third, simultaneously specializing on both the type of `dest` and the `BroadcastStyle` is fine. In this case,
it is also allowed to specialize on the types of the source arguments (`As...`). For example, these method signatures are OK:

```julia
broadcast!(f, dest::DestType, ::MyBroadcastStyle, As...)
broadcast!(f, dest::DestType, ::MyBroadcastStyle, As::AbstractArray...)
broadcast!(f, dest::DestType, ::Broadcast.Scalar, As::Number...)
```