From dbe1ae0a7e5a53a3afc92c5edb5876c522506258 Mon Sep 17 00:00:00 2001 From: Matt Bauman Date: Mon, 23 Apr 2018 19:56:46 -0400 Subject: [PATCH 1/7] Customizable lazy fused broadcasting in pure Julia MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This patch represents the combined efforts of four individuals, over 60 commits, and an iterated design over (at least) three pull requests that spanned nearly an entire year (closes #22063, #23692, #25377 by superceding them). This introduces a pure Julia data structure that represents a fused broadcast expression. For example, the expression `2 .* (x .+ 1)` lowers to: ```julia julia> Meta.@lower 2 .* (x .+ 1) :($(Expr(:thunk, CodeInfo(:(begin Core.SSAValue(0) = (Base.getproperty)(Base.Broadcast, :materialize) Core.SSAValue(1) = (Base.getproperty)(Base.Broadcast, :make) Core.SSAValue(2) = (Base.getproperty)(Base.Broadcast, :make) Core.SSAValue(3) = (Core.SSAValue(2))(+, x, 1) Core.SSAValue(4) = (Core.SSAValue(1))(*, 2, Core.SSAValue(3)) Core.SSAValue(5) = (Core.SSAValue(0))(Core.SSAValue(4)) return Core.SSAValue(5) end))))) ``` Or, slightly more readably as: ```julia using .Broadcast: materialize, make materialize(make(*, 2, make(+, x, 1))) ``` The `Broadcast.make` function serves two purposes. Its primary purpose is to construct the `Broadcast.Broadcasted` objects that hold onto the function, the tuple of arguments (potentially including nested `Broadcasted` arguments), and sometimes a set of `axes` to include knowledge of the outer shape. The secondary purpose, however, is to allow an "out" for objects that _don't_ want to participate in fusion. For example, if `x` is a range in the above `2 .* (x .+ 1)` expression, it needn't allocate an array and operate elementwise — it can just compute and return a new range. Thus custom structures are able to specialize `Broadcast.make(f, args...)` just as they'd specialize on `f` normally to return an immediate result. `Broadcast.materialize` is identity for everything _except_ `Broadcasted` objects for which it allocates an appropriate result and computes the broadcast. It does two things: it `initialize`s the outermost `Broadcasted` object to compute its axes and then `copy`s it. Similarly, an in-place fused broadcast like `y .= 2 .* (x .+ 1)` uses the exact same expression tree to compute the right-hand side of the expression as above, and then uses `materialize!(y, make(*, 2, make(+, x, 1)))` to `instantiate` the `Broadcasted` expression tree and then `copyto!` it into the given destination. All-together, this forms a complete API for custom types to extend and customize the behavior of broadcast (fixes #22060). It uses the existing `BroadcastStyle`s throughout to simplify dispatch on many arguments: * Custom types can opt-out of broadcast fusion by specializing `Broadcast.make(f, args...)` or `Broadcast.make(::BroadcastStyle, f, args...)`. * The `Broadcasted` object computes and stores the type of the combined `BroadcastStyle` of its arguments as its first type parameter, allowing for easy dispatch and specialization. * Custom Broadcast storage is still allocated via `broadcast_similar`, however instead of passing just a function as a first argument, the entire `Broadcasted` object is passed as a final argument. This potentially allows for much more runtime specialization dependent upon the exact expression given. * Custom broadcast implmentations for a `CustomStyle` are defined by specializing `copy(bc::Broadcasted{CustomStyle})` or `copyto!(dest::AbstractArray, bc::Broadcasted{CustomStyle})`. * Fallback broadcast specializations for a given output object of type `Dest` (for the `DefaultArrayStyle` or another such style that hasn't implemented assignments into such an object) are defined by specializing `copyto(dest::Dest, bc::Broadcasted{Nothing})`. As it fully supports range broadcasting, this now deprecates `(1:5) + 2` to `.+`, just as had been done for all `AbstractArray`s in general. As a first-mover proof of concept, LinearAlgebra uses this new system to improve broadcasting over structured arrays. Before, broadcasting over a structured matrix would result in a sparse array. Now, broadcasting over a structured matrix will _either_ return an appropriately structured matrix _or_ a dense array. This does incur a type instability (in the form of a discriminated union) in some situations, but thanks to type-based introspection of the `Broadcasted` wrapper commonly used functions can be special cased to be type stable. For example: ```julia julia> f(d) = round.(Int, d) f (generic function with 1 method) julia> @inferred f(Diagonal(rand(3))) 3×3 Diagonal{Int64,Array{Int64,1}}: 0 ⋅ ⋅ ⋅ 0 ⋅ ⋅ ⋅ 1 julia> @inferred Diagonal(rand(3)) .* 3 ERROR: return type Diagonal{Float64,Array{Float64,1}} does not match inferred return type Union{Array{Float64,2}, Diagonal{Float64,Array{Float64,1}}} Stacktrace: [1] error(::String) at ./error.jl:33 [2] top-level scope julia> @inferred Diagonal(1:4) .+ Bidiagonal(rand(4), rand(3), 'U') .* Tridiagonal(1:3, 1:4, 1:3) 4×4 Tridiagonal{Float64,Array{Float64,1}}: 1.30771 0.838589 ⋅ ⋅ 0.0 3.89109 0.0459757 ⋅ ⋅ 0.0 4.48033 2.51508 ⋅ ⋅ 0.0 6.23739 ``` In addition to the issues referenced above, it fixes: * Fixes #19313, #22053, #23445, and #24586: Literals are no longer treated specially in a fused broadcast; they're just arguments in a `Broadcasted` object like everything else. * Fixes #21094: Since broadcasting is now represented by a pure Julia datastructure it can be created within `@generated` functions and serialized. * Fixes #26097: The fallback destination-array specialization method of `copyto!` is specifically implemented as `Broadcasted{Nothing}` and will not be confused by `nothing` arguments. * Fixes the broadcast-specific element of #25499: The default base broadcast implementation no longer depends upon `Base._return_type` to allocate its array (except in the empty or concretely-type cases). Note that the sparse implementation (#19595) is still dependent upon inference and is _not_ fixed. * Fixes #25340: Functions are treated like normal values just like arguments and only evaluated once. * Fixes #22255, and is performant with 12+ fused broadcasts. Okay, that one was fixed on master already, but this fixes it now, too. * Fixes #25521. * The performance of this patch has been thoroughly tested through its iterative development process in #25377. There remain [two classes of performance regressions](#25377) that Nanosoldier flagged. * #25691: Propagation of constant literals sill lose their constant-ness upon going through the broadcast machinery. I believe quite a large number of functions would need to be marked as `@pure` to support this -- including functions that are intended to be specialized. (For bookkeeping, this is the squashed version of the [teh-jn/lazydotfuse](https://github.com/JuliaLang/julia/pull/25377) branch as of a1d4e7ec9756ada74fb48f2c514615b9d981cf5c. Squashed and separated out to make it easier to review and commit) Co-authored-by: Tim Holy Co-authored-by: Jameson Nash Co-authored-by: Andrew Keller --- NEWS.md | 13 +- base/bitarray.jl | 40 - base/broadcast.jl | 906 ++++++++++++------ base/compiler/ssair/inlining2.jl | 8 +- base/compiler/ssair/slot2ssa.jl | 4 +- base/deprecated.jl | 4 + base/float.jl | 10 - base/range.jl | 66 -- base/reducedim.jl | 4 +- base/sort.jl | 7 +- base/statistics.jl | 2 +- doc/src/base/arrays.md | 2 +- doc/src/manual/interfaces.md | 173 ++-- src/julia-syntax.scm | 123 +-- stdlib/LinearAlgebra/src/LinearAlgebra.jl | 3 + stdlib/LinearAlgebra/src/bidiag.jl | 11 - stdlib/LinearAlgebra/src/diagonal.jl | 1 - .../LinearAlgebra/src/structuredbroadcast.jl | 180 ++++ stdlib/LinearAlgebra/src/triangular.jl | 3 - stdlib/LinearAlgebra/src/tridiag.jl | 23 +- stdlib/LinearAlgebra/src/uniformscaling.jl | 6 +- .../LinearAlgebra/test/structuredbroadcast.jl | 101 ++ stdlib/SparseArrays/src/higherorderfns.jl | 258 +++-- stdlib/SparseArrays/test/higherorderfns.jl | 80 +- test/bitarray.jl | 35 + test/broadcast.jl | 114 ++- test/core.jl | 4 +- test/numbers.jl | 3 +- test/ranges.jl | 87 +- 29 files changed, 1407 insertions(+), 864 deletions(-) create mode 100644 stdlib/LinearAlgebra/src/structuredbroadcast.jl create mode 100644 stdlib/LinearAlgebra/test/structuredbroadcast.jl diff --git a/NEWS.md b/NEWS.md index 93d446f42874e..0da63c9973120 100644 --- a/NEWS.md +++ b/NEWS.md @@ -388,11 +388,6 @@ This section lists changes that do not have deprecation warnings. Its return value has been removed. Use the `process_running` function to determine if a process has already exited. - * Broadcasting has been redesigned with an extensible public interface. The new API is - documented at https://docs.julialang.org/en/latest/manual/interfaces/#Interfaces-1. - `AbstractArray` types that specialized broadcasting using the old internal API will - need to switch to the new API. ([#20740]) - * The logging system has been redesigned - `info` and `warn` are deprecated and replaced with the logging macros `@info`, `@warn`, `@debug` and `@error`. The `logging` function is also deprecated and replaced with @@ -418,6 +413,14 @@ This section lists changes that do not have deprecation warnings. * `findn(x::AbstractArray)` has been deprecated in favor of `findall(!iszero, x)`, which now returns cartesian indices for multidimensional arrays (see below, [#25532]). + * Broadcasting operations are no longer fused into a single operation by Julia's parser. + Instead, a lazy `Broadcasted` wrapper is created, and the parser will call + `copy(bc::Broadcasted)` or `copyto!(dest, bc::Broadcasted)` + to evaluate the wrapper. Consequently, package authors generally need to specialize + `copy` and `copyto!` methods rather than `broadcast` and `broadcast!`. + See the [Interfaces chapter](https://docs.julialang.org/en/latest/manual/interfaces/#Interfaces-1) + for more information. + * `find` has been renamed to `findall`. `findall`, `findfirst`, `findlast`, `findnext` now take and/or return the same type of indices as `keys`/`pairs` for `AbstractArray`, `AbstractDict`, `AbstractString`, `Tuple` and `NamedTuple` objects ([#24774], [#25545]). diff --git a/base/bitarray.jl b/base/bitarray.jl index bac2ad07d6a79..898980b92d4ac 100644 --- a/base/bitarray.jl +++ b/base/bitarray.jl @@ -1097,19 +1097,6 @@ function (-)(B::BitArray) end broadcast(::typeof(sign), B::BitArray) = copy(B) -function broadcast(::typeof(~), B::BitArray) - C = similar(B) - Bc = B.chunks - if !isempty(Bc) - Cc = C.chunks - for i = 1:length(Bc) - Cc[i] = ~Bc[i] - end - Cc[end] &= _msk_end(B) - end - return C -end - """ flipbits!(B::BitArray{N}) -> BitArray{N} @@ -1166,33 +1153,6 @@ end (/)(B::BitArray, x::Number) = (/)(Array(B), x) (/)(x::Number, B::BitArray) = (/)(x, Array(B)) -# broadcast specializations for &, |, and xor/⊻ -broadcast(::typeof(&), B::BitArray, x::Bool) = x ? copy(B) : falses(size(B)) -broadcast(::typeof(&), x::Bool, B::BitArray) = broadcast(&, B, x) -broadcast(::typeof(|), B::BitArray, x::Bool) = x ? trues(size(B)) : copy(B) -broadcast(::typeof(|), x::Bool, B::BitArray) = broadcast(|, B, x) -broadcast(::typeof(xor), B::BitArray, x::Bool) = x ? .~B : copy(B) -broadcast(::typeof(xor), x::Bool, B::BitArray) = broadcast(xor, B, x) -for f in (:&, :|, :xor) - @eval begin - function broadcast(::typeof($f), A::BitArray, B::BitArray) - F = BitArray(undef, promote_shape(size(A),size(B))...) - Fc = F.chunks - Ac = A.chunks - Bc = B.chunks - (isempty(Ac) || isempty(Bc)) && return F - for i = 1:length(Fc) - Fc[i] = ($f)(Ac[i], Bc[i]) - end - Fc[end] &= _msk_end(F) - return F - end - broadcast(::typeof($f), A::DenseArray{Bool}, B::BitArray) = broadcast($f, BitArray(A), B) - broadcast(::typeof($f), B::BitArray, A::DenseArray{Bool}) = broadcast($f, B, BitArray(A)) - end -end - - ## promotion to complex ## # TODO? diff --git a/base/broadcast.jl b/base/broadcast.jl index 55dd8313172b0..ef81b60f89f4b 100644 --- a/base/broadcast.jl +++ b/base/broadcast.jl @@ -3,11 +3,10 @@ module Broadcast using .Base.Cartesian -using .Base: Indices, OneTo, linearindices, tail, to_shape, - _msk_end, unsafe_bitgetindex, bitcache_chunks, bitcache_size, dumpbitcache, - isoperator, promote_typejoin, unalias -import .Base: broadcast, broadcast! -export BroadcastStyle, broadcast_indices, broadcast_similar, broadcastable, +using .Base: Indices, OneTo, linearindices, tail, to_shape, isoperator, promote_typejoin, + _msk_end, unsafe_bitgetindex, bitcache_chunks, bitcache_size, dumpbitcache, unalias +import .Base: broadcast, broadcast!, copy, copyto! +export BroadcastStyle, broadcast_axes, broadcast_similar, broadcastable, broadcast_getindex, broadcast_setindex!, dotview, @__dot__ ### Objects with customized broadcasting behavior should declare a BroadcastStyle @@ -149,47 +148,223 @@ BroadcastStyle(a::AbstractArrayStyle{N}, ::DefaultArrayStyle{N}) where N = a BroadcastStyle(a::AbstractArrayStyle{M}, ::DefaultArrayStyle{N}) where {M,N} = typeof(a)(_max(Val(M),Val(N))) +### Lazy-wrapper for broadcasting + +# `Broadcasted` wrap the arguments to `broadcast(f, args...)`. A statement like +# y = x .* (x .+ 1) +# will result in code that is essentially +# y = copy(Broadcasted(*, x, Broadcasted(+, x, 1))) +# `broadcast!` results in `copyto!(dest, Broadcasted(...))`. + +# The use of `Nothing` in place of a `BroadcastStyle` has a different +# application, in the fallback method +# copyto!(dest, bc::Broadcasted) = copyto!(dest, convert(Broadcasted{Nothing}, bc)) +# This allows methods +# copyto!(dest::DestType, bc::Broadcasted{Nothing}) +# that specialize on `DestType` to be easily disambiguated from +# methods that instead specialize on `BroadcastStyle`, +# copyto!(dest::AbstractArray, bc::Broadcasted{MyStyle}) + +struct Broadcasted{Style<:Union{Nothing,BroadcastStyle}, Axes, F, Args<:Tuple} + f::F + args::Args + axes::Axes # the axes of the resulting object (may be bigger than implied by `args` if this is nested inside a larger `Broadcasted`) +end + +Broadcasted(f::F, args::Args, axes=nothing) where {F, Args<:Tuple} = + Broadcasted{typeof(combine_styles(args...))}(f, args, axes) +function Broadcasted{Style}(f::F, args::Args, axes=nothing) where {Style, F, Args<:Tuple} + # using Core.Typeof rather than F preserves inferrability when f is a type + Broadcasted{Style, typeof(axes), Core.Typeof(f), Args}(f, args, axes) +end + +Base.convert(::Type{Broadcasted{NewStyle}}, bc::Broadcasted{Style,Axes,F,Args}) where {NewStyle,Style,Axes,F,Args} = + Broadcasted{NewStyle,Axes,F,Args}(bc.f, bc.args, bc.axes) + +Base.show(io::IO, bc::Broadcasted{Style}) where {Style} = print(io, Broadcasted, '{', Style, "}(", bc.f, ", ", bc.args, ')') + ## Allocating the output container """ - broadcast_similar(f, ::BroadcastStyle, ::Type{ElType}, inds, As...) + broadcast_similar(::BroadcastStyle, ::Type{ElType}, inds, bc) Allocate an output object for [`broadcast`](@ref), appropriate for the indicated -[`Broadcast.BroadcastStyle`](@ref). `ElType` and `inds` specify the desired element type and indices of the -container. -`f` is the broadcast operation, and `As...` are the arguments supplied to `broadcast`. +[`Broadcast.BroadcastStyle`](@ref). `ElType` and `inds` specify the desired element type and axes of the +container. The final `bc` argument is the `Broadcasted` object representing the fused broadcast operation +and its arguments. """ -broadcast_similar(f, ::DefaultArrayStyle{N}, ::Type{ElType}, inds::Indices{N}, As...) where {N,ElType} = +broadcast_similar(::DefaultArrayStyle{N}, ::Type{ElType}, inds::Indices{N}, bc) where {N,ElType} = similar(Array{ElType}, inds) -broadcast_similar(f, ::DefaultArrayStyle{N}, ::Type{Bool}, inds::Indices{N}, As...) where N = +broadcast_similar(::DefaultArrayStyle{N}, ::Type{Bool}, inds::Indices{N}, bc) where N = similar(BitArray, inds) # In cases of conflict we fall back on Array -broadcast_similar(f, ::ArrayConflict, ::Type{ElType}, inds::Indices, As...) where ElType = +broadcast_similar(::ArrayConflict, ::Type{ElType}, inds::Indices, bc) where ElType = similar(Array{ElType}, inds) -broadcast_similar(f, ::ArrayConflict, ::Type{Bool}, inds::Indices, As...) = +broadcast_similar(::ArrayConflict, ::Type{Bool}, inds::Indices, bc) = similar(BitArray, inds) -## Computing the result's indices. Most types probably won't need to specialize this. -broadcast_indices() = () -broadcast_indices(::Type{T}) where T = () -broadcast_indices(A) = broadcast_indices(combine_styles(A), A) -broadcast_indices(::Style{Tuple}, A) = (OneTo(length(A)),) -broadcast_indices(::DefaultArrayStyle{0}, A::Ref) = () -broadcast_indices(::BroadcastStyle, A) = Base.axes(A) +## Computing the result's axes. Most types probably won't need to specialize this. +broadcast_axes() = () +broadcast_axes(A::Tuple) = (OneTo(length(A)),) +broadcast_axes(A::Ref) = () +@inline broadcast_axes(A) = axes(A) """ - Base.broadcast_indices(::SrcStyle, A) + Base.broadcast_axes(A) -Compute the indices for objects `A` with [`BroadcastStyle`](@ref) `SrcStyle`. -If needed, you can specialize this method for your styles. -You should only need to provide a custom implementation for non-AbstractArrayStyles. +Compute the axes for `A`. + +This should only be specialized for objects that do not define axes but want to participate in broadcasting. """ -broadcast_indices +broadcast_axes ### End of methods that users will typically have to specialize ### +@inline Base.axes(bc::Broadcasted) = _axes(bc, bc.axes) +_axes(::Broadcasted, axes::Tuple) = axes +@inline _axes(bc::Broadcasted, ::Nothing) = combine_axes(bc.args...) +_axes(bc::Broadcasted{Style{Tuple}}, ::Nothing) = (Base.OneTo(length(longest_tuple(nothing, bc.args))),) +_axes(bc::Broadcasted{<:AbstractArrayStyle{0}}, ::Nothing) = () + +BroadcastStyle(::Type{<:Broadcasted{Style}}) where {Style} = Style() +BroadcastStyle(::Type{<:Broadcasted{S}}) where {S<:Union{Nothing,Unknown}} = + throw(ArgumentError("Broadcasted{Unknown} wrappers do not have a style assigned")) + +argtype(::Type{Broadcasted{Style,Axes,F,Args}}) where {Style,Axes,F,Args} = Args +argtype(bc::Broadcasted) = argtype(typeof(bc)) + +const NestedTuple = Tuple{<:Broadcasted,Vararg{Any}} +not_nested(bc::Broadcasted) = _not_nested(bc.args) +_not_nested(t::Tuple) = _not_nested(tail(t)) +_not_nested(::NestedTuple) = false +_not_nested(::Tuple{}) = true + +## Instantiation fills in the "missing" fields in Broadcasted. +instantiate(x) = x + +""" + Broadcast.instantiate(bc::Broadcasted) + +Construct the axes and indexing helpers for the lazy Broadcasted object `bc`. + +Custom `BroadcastStyle`s may override this default in cases where it is fast and easy +to compute the resulting `axes` and indexing helpers on-demand, leaving those fields +of the `Broadcasted` object empty (populated with `nothing`). If they do so, however, +they must provide their own `Base.axes(::Broadcasted{Style})` and +`Base.getindex(::Broadcasted{Style}, I::Union{Int,CartesianIndex})` methods as appropriate. +""" +@inline function instantiate(bc::Broadcasted{Style}) where {Style} + if bc.axes isa Nothing # Not done via dispatch to make it easier to extend instantiate(::Broadcasted{Style}) + axes = combine_axes(bc.args...) + else + axes = bc.axes + check_broadcast_axes(axes, bc.args...) + end + return Broadcasted{Style}(bc.f, bc.args, axes) +end +instantiate(bc::Broadcasted{<:Union{AbstractArrayStyle{0}, Style{Tuple}}}) = bc + +## Flattening + +""" + bcf = flatten(bc) + +Create a "flat" representation of a lazy-broadcast operation. +From + f.(a, g.(b, c), d) +we produce the equivalent of + h.(a, b, c, d) +where + h(w, x, y, z) = f(w, g(x, y), z) +In terms of its internal representation, + Broadcasted(f, a, Broadcasted(g, b, c), d) +becomes + Broadcasted(h, a, b, c, d) + +This is an optional operation that may make custom implementation of broadcasting easier in +some cases. +""" +function flatten(bc::Broadcasted{Style}) where {Style} + isflat(bc.args) && return bc + # concatenate the nested arguments into {a, b, c, d} + args = cat_nested(x->x.args, bc) + # build a function `makeargs` that takes a "flat" argument list and + # and creates the appropriate input arguments for `f`, e.g., + # makeargs = (w, x, y, z) -> (w, g(x, y), z) + # + # `makeargs` is built recursively and looks a bit like this: + # makeargs(w, x, y, z) = (w, makeargs1(x, y, z)...) + # = (w, g(x, y), makeargs2(z)...) + # = (w, g(x, y), z) + let makeargs = make_makeargs(bc) + newf = @inline function(args::Vararg{Any,N}) where N + bc.f(makeargs(args...)...) + end + return Broadcasted{Style}(newf, args, bc.axes) + end +end + +isflat(args::NestedTuple) = false +isflat(args::Tuple) = isflat(tail(args)) +isflat(args::Tuple{}) = true + +cat_nested(fieldextractor, bc::Broadcasted) = cat_nested(fieldextractor, fieldextractor(bc), ()) + +cat_nested(fieldextractor, t::Tuple, rest) = + (t[1], cat_nested(fieldextractor, tail(t), rest)...) +cat_nested(fieldextractor, t::Tuple{<:Broadcasted,Vararg{Any}}, rest) = + cat_nested(fieldextractor, cat_nested(fieldextractor, fieldextractor(t[1]), tail(t)), rest) +cat_nested(fieldextractor, t::Tuple{}, tail) = cat_nested(fieldextractor, tail, ()) +cat_nested(fieldextractor, t::Tuple{}, tail::Tuple{}) = () + +make_makeargs(bc::Broadcasted) = make_makeargs(()->(), bc.args) +@inline function make_makeargs(makeargs, t::Tuple) + let makeargs = make_makeargs(makeargs, tail(t)) + return @inline function(head, tail::Vararg{Any,N}) where N + (head, makeargs(tail...)...) + end + end +end +@inline function make_makeargs(makeargs, t::Tuple{<:Broadcasted,Vararg{Any}}) + bc = t[1] + let makeargs = make_makeargs(makeargs, tail(t)) + let makeargs = make_makeargs(makeargs, bc.args) + headargs, tailargs = make_headargs(bc.args), make_tailargs(bc.args) + return @inline function(args::Vararg{Any,N}) where N + args1 = makeargs(args...) + a, b = headargs(args1...), tailargs(args1...) + (bc.f(a...), b...) + end + end + end +end +make_makeargs(makeargs, ::Tuple{}) = makeargs + +@inline function make_headargs(t::Tuple) + let headargs = make_headargs(tail(t)) + return @inline function(head, tail::Vararg{Any,N}) where N + (head, headargs(tail...)...) + end + end +end +@inline function make_headargs(::Tuple{}) + return @inline function(tail::Vararg{Any,N}) where N + () + end +end + +@inline function make_tailargs(t::Tuple) + let tailargs = make_tailargs(tail(t)) + return @inline function(head, tail::Vararg{Any,N}) where N + tailargs(tail...) + end + end +end +@inline function make_tailargs(::Tuple{}) + return @inline function(tail::Vararg{Any,N}) where N + tail + end +end + ## Broadcasting utilities ## -# special cases defined for performance -broadcast(f, x::Number...) = f(x...) -@inline broadcast(f, t::NTuple{N,Any}, ts::Vararg{NTuple{N,Any}}) where {N} = map(f, t, ts...) ## logic for deciding the BroadcastStyle # Dimensionality: computing max(M,N) in the type domain so we preserve inferrability @@ -204,6 +379,7 @@ longest(t1::Tuple, ::Tuple{}) = (true, longest(Base.tail(t1), ())...) longest(::Tuple{}, ::Tuple{}) = () # combine_styles operates on values (arbitrarily many) +combine_styles() = DefaultArrayStyle{0}() combine_styles(c) = result_style(BroadcastStyle(typeof(c))) combine_styles(c1, c2) = result_style(combine_styles(c1), combine_styles(c2)) @inline combine_styles(c1, c2, cs...) = result_style(combine_styles(c1), combine_styles(c2, cs...)) @@ -236,8 +412,8 @@ One of these should be undefined (and thus return Broadcast.Unknown).""") end # Indices utilities -combine_indices(A, B...) = broadcast_shape(broadcast_indices(A), combine_indices(B...)) -combine_indices(A) = broadcast_indices(A) +@inline combine_axes(A, B...) = broadcast_shape(broadcast_axes(A), combine_axes(B...)) +combine_axes(A) = broadcast_axes(A) # shape (i.e., tuple-of-indices) inputs broadcast_shape(shape::Tuple) = shape @@ -269,119 +445,124 @@ function check_broadcast_shape(shp, Ashp::Tuple) _bcsm(shp[1], Ashp[1]) || throw(DimensionMismatch("array could not be broadcast to match destination")) check_broadcast_shape(tail(shp), tail(Ashp)) end -check_broadcast_indices(shp, A) = check_broadcast_shape(shp, broadcast_indices(A)) +check_broadcast_axes(shp, A) = check_broadcast_shape(shp, broadcast_axes(A)) # comparing many inputs -@inline function check_broadcast_indices(shp, A, As...) - check_broadcast_indices(shp, A) - check_broadcast_indices(shp, As...) +@inline function check_broadcast_axes(shp, A, As...) + check_broadcast_axes(shp, A) + check_broadcast_axes(shp, As...) end ## Indexing manipulations +""" + newindex(argument, I) + newindex(I, keep, default) + +Recompute index `I` such that it appropriately constrains broadcasted dimensions to the source. + +Two methods are supported, both allowing for `I` to be specified as either a `CartesianIndex` or +an `Int`. + +* `newindex(argument, I)` dynamically constrains `I` based upon the axes of `argument`. +* `newindex(I, keep, default)` constrains `I` using the pre-computed tuples `keeps` and `defaults`. + * `keep` is a tuple of `Bool`s, where `keep[d] == true` means that dimension `d` in `I` should be preserved as is + * `default` is a tuple of Integers, specifying what index to use in dimension `d` when `keep[d] == false`. + Any remaining indices in `I` beyond the length of the `keep` tuple are truncated. The `keep` and `default` + tuples may be created by `newindexer(argument)`. +""" +Base.@propagate_inbounds newindex(arg, I::CartesianIndex) = CartesianIndex(_newindex(broadcast_axes(arg), I.I)) +Base.@propagate_inbounds newindex(arg, I::Int) = CartesianIndex(_newindex(broadcast_axes(arg), (I,))) +Base.@propagate_inbounds _newindex(ax::Tuple, I::Tuple) = (ifelse(Base.unsafe_length(ax[1])==1, ax[1][1], I[1]), _newindex(tail(ax), tail(I))...) +Base.@propagate_inbounds _newindex(ax::Tuple{}, I::Tuple) = () +Base.@propagate_inbounds _newindex(ax::Tuple, I::Tuple{}) = (ax[1][1], _newindex(tail(ax), ())...) +Base.@propagate_inbounds _newindex(ax::Tuple{}, I::Tuple{}) = () -# newindex(I, keep, Idefault) replaces a CartesianIndex `I` with something that -# is appropriate for a particular broadcast array/scalar. `keep` is a -# NTuple{N,Bool}, where keep[d] == true means that one should preserve -# I[d]; if false, replace it with Idefault[d]. # If dot-broadcasting were already defined, this would be `ifelse.(keep, I, Idefault)`. @inline newindex(I::CartesianIndex, keep, Idefault) = CartesianIndex(_newindex(I.I, keep, Idefault)) +@inline newindex(i::Int, keep::Tuple{Bool}, idefault) = ifelse(keep[1], i, idefault[1]) @inline _newindex(I, keep, Idefault) = (ifelse(keep[1], I[1], Idefault[1]), _newindex(tail(I), tail(keep), tail(Idefault))...) @inline _newindex(I, keep::Tuple{}, Idefault) = () # truncate if keep is shorter than I -# newindexer(shape, A) generates `keep` and `Idefault` (for use by -# `newindex` above) for a particular array `A`, given the -# broadcast indices `shape` -# `keep` is equivalent to map(==, axes(A), shape) (but see #17126) -@inline newindexer(shape, A) = shapeindexer(shape, broadcast_indices(A)) -@inline shapeindexer(shape, indsA::Tuple{}) = (), () -@inline function shapeindexer(shape, indsA::Tuple) +# newindexer(A) generates `keep` and `Idefault` (for use by `newindex` above) +# for a particular array `A`; `shapeindexer` does so for its axes. +@inline newindexer(A) = shapeindexer(broadcast_axes(A)) +@inline shapeindexer(ax) = _newindexer(ax) +@inline _newindexer(indsA::Tuple{}) = (), () +@inline function _newindexer(indsA::Tuple) ind1 = indsA[1] - keep, Idefault = shapeindexer(tail(shape), tail(indsA)) - (shape[1] == ind1, keep...), (first(ind1), Idefault...) + keep, Idefault = _newindexer(tail(indsA)) + (length(ind1)!=1, keep...), (first(ind1), Idefault...) end -# Equivalent to map(x->newindexer(shape, x), As) (but see #17126) -map_newindexer(shape, ::Tuple{}) = (), () -@inline function map_newindexer(shape, As) - A1 = As[1] - keeps, Idefaults = map_newindexer(shape, tail(As)) - keep, Idefault = newindexer(shape, A1) - (keep, keeps...), (Idefault, Idefaults...) -end -@inline function map_newindexer(shape, A, Bs) - keeps, Idefaults = map_newindexer(shape, Bs) - keep, Idefault = newindexer(shape, A) - (keep, keeps...), (Idefault, Idefaults...) +@inline function Base.getindex(bc::Broadcasted, I) + @boundscheck checkbounds(bc, I) + @inbounds _broadcast_getindex(bc, I) end -Base.@propagate_inbounds _broadcast_getindex(::Type{T}, I) where T = T -Base.@propagate_inbounds _broadcast_getindex(A, I) = _broadcast_getindex(combine_styles(A), A, I) -Base.@propagate_inbounds _broadcast_getindex(::DefaultArrayStyle{0}, A, I) = A[] -Base.@propagate_inbounds _broadcast_getindex(::Any, A, I) = A[I] -Base.@propagate_inbounds _broadcast_getindex(::Style{Tuple}, A::Tuple{Any}, I) = A[1] +@inline Base.checkbounds(bc::Broadcasted, I) = + Base.checkbounds_indices(Bool, axes(bc), (I,)) || Base.throw_boundserror(bc, (I,)) -## Broadcasting core -# nargs encodes the number of As arguments (which matches the number -# of keeps). The first two type parameters are to ensure specialization. -@generated function _broadcast!(f, B::AbstractArray, keeps::K, Idefaults::ID, A::AT, Bs::BT, ::Val{N}, iter) where {K,ID,AT,BT,N} - nargs = N + 1 - quote - $(Expr(:meta, :inline)) - # destructure the keeps and As tuples - A_1 = A - @nexprs $N i->(A_{i+1} = Bs[i]) - @nexprs $nargs i->(keep_i = keeps[i]) - @nexprs $nargs i->(Idefault_i = Idefaults[i]) - @simd for I in iter - # reverse-broadcast the indices - @nexprs $nargs i->(I_i = newindex(I, keep_i, Idefault_i)) - # extract array values - @nexprs $nargs i->(@inbounds val_i = _broadcast_getindex(A_i, I_i)) - # call the function and store the result - result = @ncall $nargs f val - @inbounds B[I] = result - end - return B - end -end -# For BitArray outputs, we cache the result in a "small" Vector{Bool}, -# and then copy in chunks into the output -@generated function _broadcast!(f, B::BitArray, keeps::K, Idefaults::ID, A::AT, Bs::BT, ::Val{N}, iter) where {K,ID,AT,BT,N} - nargs = N + 1 - quote - $(Expr(:meta, :inline)) - # destructure the keeps and As tuples - A_1 = A - @nexprs $N i->(A_{i+1} = Bs[i]) - @nexprs $nargs i->(keep_i = keeps[i]) - @nexprs $nargs i->(Idefault_i = Idefaults[i]) - C = Vector{Bool}(undef, bitcache_size) - Bc = B.chunks - ind = 1 - cind = 1 - @simd for I in iter - # reverse-broadcast the indices - @nexprs $nargs i->(I_i = newindex(I, keep_i, Idefault_i)) - # extract array values - @nexprs $nargs i->(@inbounds val_i = _broadcast_getindex(A_i, I_i)) - # call the function and store the result - @inbounds C[ind] = @ncall $nargs f val - ind += 1 - if ind > bitcache_size - dumpbitcache(Bc, cind, C) - cind += bitcache_chunks - ind = 1 - end - end - if ind > 1 - @inbounds C[ind:bitcache_size] = false - dumpbitcache(Bc, cind, C) - end - return B - end +""" + _broadcast_getindex(A, I) + +Index into `A` with `I`, collapsing broadcasted indices to their singleton indices as appropriate +""" +Base.@propagate_inbounds _broadcast_getindex(A::Union{Ref,AbstractArray{<:Any,0},Number}, I) = A[] # Scalar-likes can just ignore all indices +Base.@propagate_inbounds _broadcast_getindex(::Ref{Type{T}}, I) where {T} = T +# Tuples are statically known to be singleton or vector-like +Base.@propagate_inbounds _broadcast_getindex(A::Tuple{Any}, I) = A[1] +Base.@propagate_inbounds _broadcast_getindex(A::Tuple, I) = A[I[1]] +# Everything else falls back to dynamically dropping broadcasted indices based upon its axes +Base.@propagate_inbounds _broadcast_getindex(A, I) = A[newindex(A, I)] + +# In some cases, it's more efficient to sort out which dimensions should be dropped +# ahead of time (often when the size checks aren't able to be lifted out of the loop). +# The Extruded struct computes that information ahead of time and stores it as a pair +# of tuples to optimize indexing later. This is most commonly needed for `Array` and +# other `AbstractArray` subtypes that wrap `Array` and dynamically ask it for its size. +struct Extruded{T, K, D} + x::T + keeps::K # A tuple of booleans, specifying which indices should be passed normally + defaults::D # A tuple of integers, specifying the index to use when keeps[i] is false (as defaults[i]) +end +@inline broadcast_axes(b::Extruded) = broadcast_axes(b.x) +Base.@propagate_inbounds _broadcast_getindex(b::Extruded, i) = b.x[newindex(i, b.keeps, b.defaults)] +extrude(x::AbstractArray) = Extruded(x, newindexer(x)...) +extrude(x) = x + +# For Broadcasted +Base.@propagate_inbounds function _broadcast_getindex(bc::Broadcasted{<:Any,<:Any,<:Any,<:Any}, I) + args = _getindex(bc.args, I) + return _broadcast_getindex_evalf(bc.f, args...) +end +# Hack around losing Type{T} information in the final args tuple. Julia actually +# knows (in `code_typed`) the _value_ of these types, statically displaying them, +# but inference is currently skipping inferring the type of the types as they are +# transiently placed in a tuple as the argument list is lispily constructed. These +# additional methods recover type stability when a `Type` appears in one of the +# first two arguments of a function. +Base.@propagate_inbounds function _broadcast_getindex(bc::Broadcasted{<:Any,<:Any,<:Any,<:Tuple{Ref{Type{T}},Vararg{Any}}}, I) where {T} + args = _getindex(tail(bc.args), I) + return _broadcast_getindex_evalf(bc.f, T, args...) +end +Base.@propagate_inbounds function _broadcast_getindex(bc::Broadcasted{<:Any,<:Any,<:Any,<:Tuple{Any,Ref{Type{T}},Vararg{Any}}}, I) where {T} + arg1 = _broadcast_getindex(bc.args[1], I) + args = _getindex(tail(tail(bc.args)), I) + return _broadcast_getindex_evalf(bc.f, arg1, T, args...) +end +Base.@propagate_inbounds function _broadcast_getindex(bc::Broadcasted{<:Any,<:Any,<:Any,<:Tuple{Ref{Type{T}},Ref{Type{S}},Vararg{Any}}}, I) where {T,S} + args = _getindex(tail(tail(bc.args)), I) + return _broadcast_getindex_evalf(bc.f, T, S, args...) end +# Utilities for _broadcast_getindex +Base.@propagate_inbounds _getindex(args::Tuple, I) = (_broadcast_getindex(args[1], I), _getindex(tail(args), I)...) +Base.@propagate_inbounds _getindex(args::Tuple{Any}, I) = (_broadcast_getindex(args[1], I),) +Base.@propagate_inbounds _getindex(args::Tuple{}, I) = () + +@inline _broadcast_getindex_evalf(f::Tf, args::Vararg{Any,N}) where {Tf,N} = f(args...) # not propagate_inbounds + """ broadcastable(x) @@ -410,129 +591,27 @@ julia> broadcastable("hello") # Strings break convention of matching iteration a Base.RefValue{String}("hello") ``` """ -broadcastable(x::Union{Symbol,AbstractString,Function,UndefInitializer,Nothing,RoundingMode,Missing}) = Ref(x) +broadcastable(x::Union{Symbol,AbstractString,Function,UndefInitializer,Nothing,RoundingMode,Missing,Val}) = Ref(x) broadcastable(x::Ptr) = Ref{Ptr}(x) # Cannot use Ref(::Ptr) until ambiguous deprecation goes through broadcastable(::Type{T}) where {T} = Ref{Type{T}}(T) -broadcastable(x::Union{AbstractArray,Number,Ref,Tuple}) = x +broadcastable(x::Union{AbstractArray,Number,Ref,Tuple,Broadcasted}) = x # In the future, default to collecting arguments. TODO: uncomment once deprecations are removed # broadcastable(x) = collect(x) # broadcastable(::Union{AbstractDict, NamedTuple}) = error("intentionally unimplemented to allow development in 1.x") -""" - broadcast!(f, dest, As...) +## Computation of inferred result type, for empty and concretely inferred cases only +_broadcast_getindex_eltype(bc::Broadcasted) = Base._return_type(bc.f, eltypes(bc.args)) +_broadcast_getindex_eltype(A) = eltype(A) # Tuple, Array, etc. -Like [`broadcast`](@ref), but store the result of -`broadcast(f, As...)` in the `dest` array. -Note that `dest` is only used to store the result, and does not supply -arguments to `f` unless it is also listed in the `As`, -as in `broadcast!(f, A, A, B)` to perform `A[:] = broadcast(f, A, B)`. -""" -@inline function broadcast!(f::Tf, dest, As::Vararg{Any,N}) where {Tf,N} - As′ = map(broadcastable, As) - broadcast!(f, dest, combine_styles(As′...), As′...) -end -@inline broadcast!(f::Tf, dest, ::BroadcastStyle, As::Vararg{Any,N}) where {Tf,N} = broadcast!(f, dest, nothing, As...) +eltypes(::Tuple{}) = Tuple{} +eltypes(t::Tuple{Any}) = Tuple{_broadcast_getindex_eltype(t[1])} +eltypes(t::Tuple{Any,Any}) = Tuple{_broadcast_getindex_eltype(t[1]), _broadcast_getindex_eltype(t[2])} +eltypes(t::Tuple) = Tuple{_broadcast_getindex_eltype(t[1]), eltypes(tail(t)).types...} -# Default behavior (separated out so that it can be called by users who want to extend broadcast!). -@inline function broadcast!(f, dest, ::Nothing, As::Vararg{Any, N}) where N - if f isa typeof(identity) && N == 1 - A = As[1] - if A isa AbstractArray && Base.axes(dest) == Base.axes(A) - return copyto!(dest, A) - end - end - _broadcast!(f, dest, As...) - return dest -end +# Inferred eltype of result of broadcast(f, args...) +combine_eltypes(f, args::Tuple) = Base._return_type(f, eltypes(args)) -# Optimization for the case where all arguments are 0-dimensional -@inline function broadcast!(f, dest, ::AbstractArrayStyle{0}, As::Vararg{Any, N}) where N - if dest isa AbstractArray - if f isa typeof(identity) && N == 1 - return fill!(dest, As[1][]) - else - @inbounds for I in eachindex(dest) - dest[I] = f(map(getindex, As)...) - end - return dest - end - end - _broadcast!(f, dest, As...) - return dest -end - -# For broadcasted assignments like `broadcast!(f, A, ..., A, ...)`, where `A` -# appears on both the LHS and the RHS of the `.=`, then we know we're only -# going to make one pass through the array, and even though `A` is aliasing -# against itself, the mutations won't affect the result as the indices on the -# LHS and RHS will always match. This is not true in general, but with the `.op=` -# syntax it's fairly common for an argument to be `===` a source. -broadcast_unalias(dest, src) = dest === src ? src : unalias(dest, src) - -# This indirection allows size-dependent implementations. -@inline function _broadcast!(f, C, A, Bs::Vararg{Any,N}) where N - shape = broadcast_indices(C) - @boundscheck check_broadcast_indices(shape, A, Bs...) - A′ = broadcast_unalias(C, A) - Bs′ = map(B->broadcast_unalias(C, B), Bs) - keeps, Idefaults = map_newindexer(shape, A′, Bs′) - iter = CartesianIndices(shape) - _broadcast!(f, C, keeps, Idefaults, A′, Bs′, Val(N), iter) - return C -end - -# broadcast with element type adjusted on-the-fly. This widens the element type of -# B as needed (allocating a new container and copying previously-computed values) to -# accommodate any incompatible new elements. -@generated function _broadcast!(f, B::AbstractArray, keeps::K, Idefaults::ID, As::AT, ::Val{nargs}, iter, st, count) where {K,ID,AT,nargs} - quote - $(Expr(:meta, :noinline)) - # destructure the keeps and As tuples - @nexprs $nargs i->(A_i = As[i]) - @nexprs $nargs i->(keep_i = keeps[i]) - @nexprs $nargs i->(Idefault_i = Idefaults[i]) - while !done(iter, st) - I, st = next(iter, st) - # reverse-broadcast the indices - @nexprs $nargs i->(I_i = newindex(I, keep_i, Idefault_i)) - # extract array values - @nexprs $nargs i->(@inbounds val_i = _broadcast_getindex(A_i, I_i)) - # call the function - V = @ncall $nargs f val - # store the result - if V isa eltype(B) - @inbounds B[I] = V - else - # This element type doesn't fit in B. Allocate a new B with wider eltype, - # copy over old values, and continue - newB = Base.similar(B, promote_typejoin(eltype(B), typeof(V))) - for II in Iterators.take(iter, count) - newB[II] = B[II] - end - newB[I] = V - return _broadcast!(f, newB, keeps, Idefaults, As, Val(nargs), iter, st, count+1) - end - count += 1 - end - return B - end -end - -maptoTuple(f) = Tuple{} -maptoTuple(f, a, b...) = Tuple{f(a), maptoTuple(f, b...).types...} - -# An element type satisfying for all A: -# broadcast_getindex( -# combine_styles(A), -# A, broadcast_indices(A) -# )::_broadcast_getindex_eltype(A) -_broadcast_getindex_eltype(A) = _broadcast_getindex_eltype(combine_styles(A), A) -_broadcast_getindex_eltype(::BroadcastStyle, A) = eltype(A) # Tuple, Array, etc. -_broadcast_getindex_eltype(::DefaultArrayStyle{0}, ::Ref{T}) where {T} = T - -# Inferred eltype of result of broadcast(f, xs...) -combine_eltypes(f, A, As...) = - Base._return_type(f, maptoTuple(_broadcast_getindex_eltype, A, As...)) +## Broadcasting core """ broadcast(f, As...) @@ -610,77 +689,294 @@ julia> string.(("one","two","three","four"), ": ", 1:4) ``` """ -@inline function broadcast(f, A, Bs...) - A′ = broadcastable(A) - Bs′ = map(broadcastable, Bs) - broadcast(f, combine_styles(A′, Bs′...), nothing, nothing, A′, Bs′...) +broadcast(f::Tf, As...) where {Tf} = copy(instantiate(make(f, As...))) + +# special cases defined for performance +@inline broadcast(f, x::Number...) = f(x...) +@inline broadcast(f, t::NTuple{N,Any}, ts::Vararg{NTuple{N,Any}}) where {N} = map(f, t, ts...) + +""" + broadcast!(f, dest, As...) + +Like [`broadcast`](@ref), but store the result of +`broadcast(f, As...)` in the `dest` array. +Note that `dest` is only used to store the result, and does not supply +arguments to `f` unless it is also listed in the `As`, +as in `broadcast!(f, A, A, B)` to perform `A[:] = broadcast(f, A, B)`. +""" +broadcast!(f::Tf, dest, As::Vararg{Any,N}) where {Tf,N} = (materialize!(dest, make(f, As...)); dest) + +""" + Broadcast.materialize(bc) + +Take a lazy `Broadcasted` object and compute the result +""" +@inline materialize(bc::Broadcasted) = copy(instantiate(bc)) +materialize(x) = x +@inline function materialize!(dest, bc::Broadcasted{Style}) where {Style} + return copyto!(dest, instantiate(Broadcasted{Style}(bc.f, bc.args, axes(dest)))) +end +@inline function materialize!(dest, x) + return copyto!(dest, instantiate(Broadcasted(identity, (x,), axes(dest)))) end -# In the scalar case we unwrap the arguments and just call `f` -@inline broadcast(f, ::AbstractArrayStyle{0}, ::Nothing, ::Nothing, A, Bs...) = f(A[], map(getindex, Bs)...) +## general `copy` methods +@inline copy(bc::Broadcasted{<:AbstractArrayStyle{0}}) = bc[CartesianIndex()] +copy(bc::Broadcasted{<:Union{Nothing,Unknown}}) = + throw(ArgumentError("broadcasting requires an assigned BroadcastStyle")) -@inline broadcast(f, s::BroadcastStyle, ::Nothing, ::Nothing, A, Bs...) = - broadcast(f, s, combine_eltypes(f, A, Bs...), combine_indices(A, Bs...), A, Bs...) +const NonleafHandlingStyles = Union{DefaultArrayStyle,ArrayConflict} -const NonleafHandlingTypes = Union{DefaultArrayStyle,ArrayConflict} +@inline function copy(bc::Broadcasted{Style}) where {Style} + ElType = combine_eltypes(bc.f, bc.args) + if Base.isconcretetype(ElType) + # We can trust it and defer to the simpler `copyto!` + return copyto!(broadcast_similar(Style(), ElType, axes(bc), bc), bc) + end + # When ElType is not concrete, use narrowing. Use the first output + # value to determine the starting output eltype; copyto_nonleaf! + # will widen `dest` as needed to accommodate later values. + bc′ = preprocess(nothing, bc) + iter = CartesianIndices(axes(bc′)) + state = start(iter) + if done(iter, state) + # if empty, take the ElType at face value + return broadcast_similar(Style(), ElType, axes(bc′), bc′) + end + # Initialize using the first value + I, state = next(iter, state) + @inbounds val = bc′[I] + dest = broadcast_similar(Style(), typeof(val), axes(bc′), bc′) + @inbounds dest[I] = val + # Now handle the remaining values + return copyto_nonleaf!(dest, bc′, iter, state, 1) +end -@inline function broadcast(f, s::NonleafHandlingTypes, ::Type{ElType}, inds::Indices, As...) where ElType - if !Base.isconcretetype(ElType) - return broadcast_nonleaf(f, s, ElType, inds, As...) +## general `copyto!` methods +# The most general method falls back to a method that replaces Style->Nothing +# This permits specialization on typeof(dest) without introducing ambiguities +@inline copyto!(dest::AbstractArray, bc::Broadcasted) = copyto!(dest, convert(Broadcasted{Nothing}, bc)) + +# Performance optimization for the Scalar case +@inline function copyto!(dest::AbstractArray, bc::Broadcasted{<:AbstractArrayStyle{0}}) + if not_nested(bc) + if bc.f === identity && bc.args isa Tuple{Any} # only a single input argument to broadcast! + # broadcast!(identity, dest, val) is equivalent to fill!(dest, val) + return fill!(dest, bc.args[1][]) + else + args = bc.args + @inbounds for I in eachindex(dest) + dest[I] = bc.f(map(getindex, args)...) + end + return dest + end end - dest = broadcast_similar(f, s, ElType, inds, As...) - broadcast!(f, dest, As...) + # Fall back to the default implementation + return copyto!(dest, instantiate(bc)) end -@inline function broadcast(f, s::BroadcastStyle, ::Type{ElType}, inds::Indices, As...) where ElType - dest = broadcast_similar(f, s, ElType, inds, As...) - broadcast!(f, dest, As...) +# For broadcasted assignments like `broadcast!(f, A, ..., A, ...)`, where `A` +# appears on both the LHS and the RHS of the `.=`, then we know we're only +# going to make one pass through the array, and even though `A` is aliasing +# against itself, the mutations won't affect the result as the indices on the +# LHS and RHS will always match. This is not true in general, but with the `.op=` +# syntax it's fairly common for an argument to be `===` a source. +broadcast_unalias(dest, src) = dest === src ? src : unalias(dest, src) +broadcast_unalias(::Nothing, src) = src + +# Preprocessing a `Broadcasted` does two things: +# * unaliases any arguments from `dest` +# * "extrudes" the arguments where it is advantageous to pre-compute the broadcasted indices +@inline preprocess(dest, bc::Broadcasted{Style}) where {Style} = Broadcasted{Style}(bc.f, preprocess_args(dest, bc.args), bc.axes) +preprocess(dest, x) = extrude(broadcast_unalias(dest, x)) + +@inline preprocess_args(dest, args::Tuple) = (preprocess(dest, args[1]), preprocess_args(dest, tail(args))...) +preprocess_args(dest, args::Tuple{Any}) = (preprocess(dest, args[1]),) +preprocess_args(dest, args::Tuple{}) = () + +# Specialize this method if all you want to do is specialize on typeof(dest) +@inline function copyto!(dest::AbstractArray, bc::Broadcasted{Nothing}) + axes(dest) == axes(bc) || throwdm(axes(dest), axes(bc)) + # Performance optimization: broadcast!(identity, dest, A) is equivalent to copyto!(dest, A) if indices match + if bc.f === identity && bc.args isa Tuple{<:AbstractArray} # only a single input argument to broadcast! + A = bc.args[1] + if axes(dest) == axes(A) + return copyto!(dest, A) + end + end + bc′ = preprocess(dest, bc) + @simd for I in CartesianIndices(axes(bc′)) + @inbounds dest[I] = bc′[I] + end + return dest end -# When ElType is not concrete, use narrowing. Use the first element of each input to determine -# the starting output eltype; the _broadcast! method will widen `dest` as needed to -# accommodate later values. -function broadcast_nonleaf(f, s::NonleafHandlingTypes, ::Type{ElType}, shape::Indices, As...) where ElType - nargs = length(As) - iter = CartesianIndices(shape) - if isempty(iter) - return Base.similar(Array{ElType}, shape) +# Performance optimization: for BitArray outputs, we cache the result +# in a "small" Vector{Bool}, and then copy in chunks into the output +function copyto!(dest::BitArray, bc::Broadcasted{Nothing}) + axes(dest) == axes(bc) || throwdm(axes(dest), axes(bc)) + ischunkedbroadcast(dest, bc) && return chunkedcopyto!(dest, bc) + tmp = Vector{Bool}(undef, bitcache_size) + destc = dest.chunks + ind = cind = 1 + bc′ = preprocess(dest, bc) + @simd for I in CartesianIndices(axes(bc′)) + @inbounds tmp[ind] = bc′[I] + ind += 1 + if ind > bitcache_size + dumpbitcache(destc, cind, tmp) + cind += bitcache_chunks + ind = 1 + end end - keeps, Idefaults = map_newindexer(shape, As) - st = start(iter) - I, st = next(iter, st) - val = f([ _broadcast_getindex(As[i], newindex(I, keeps[i], Idefaults[i])) for i=1:nargs ]...) - if val isa Bool - dest = Base.similar(BitArray, shape) - else - dest = Base.similar(Array{typeof(val)}, shape) + if ind > 1 + @inbounds tmp[ind:bitcache_size] = false + dumpbitcache(destc, cind, tmp) end - dest[I] = val - _broadcast!(f, dest, keeps, Idefaults, As, Val(nargs), iter, st, 1) -end - -@inline broadcast(f, ::Style{Tuple}, ::Nothing, ::Nothing, A, Bs...) = - tuplebroadcast(f, longest_tuple(A, Bs...), A, Bs...) -@inline tuplebroadcast(f, ::NTuple{N,Any}, As...) where {N} = - ntuple(k -> f(tuplebroadcast_getargs(As, k)...), Val(N)) -@inline tuplebroadcast(f, ::NTuple{N,Any}, ::Ref{Type{T}}, As...) where {N,T} = - ntuple(k -> f(T, tuplebroadcast_getargs(As, k)...), Val(N)) -longest_tuple(A::Tuple, B::Tuple, Bs...) = longest_tuple(_longest_tuple(A, B), Bs...) -longest_tuple(A, B::Tuple, Bs...) = longest_tuple(B, Bs...) -longest_tuple(A::Tuple, B, Bs...) = longest_tuple(A, Bs...) -longest_tuple(A, B, Bs...) = longest_tuple(Bs...) -longest_tuple(A::Tuple) = A + return dest +end + +# For some BitArray operations, we can work at the level of chunks. The trivial +# implementation just walks over the UInt64 chunks in a linear fashion. +# This requires three things: +# 1. The function must be known to work at the level of chunks +# 2. The only arrays involved must be BitArrays or scalars +# 3. There must not be any broadcasting beyond scalar — all array sizes must match +# We could eventually allow for all broadcasting and other array types, but that +# requires very careful consideration of all the edge effects. +const ChunkableOp = Union{typeof(&), typeof(|), typeof(xor), typeof(~)} +const BroadcastedChunkableOp{Style<:Union{Nothing,BroadcastStyle}, Axes, F<:ChunkableOp, Args<:Tuple} = Broadcasted{Style,Axes,F,Args} +ischunkedbroadcast(R, bc::BroadcastedChunkableOp) = ischunkedbroadcast(R, bc.args) +ischunkedbroadcast(R, args) = false +ischunkedbroadcast(R, args::Tuple{<:BitArray,Vararg{Any}}) = size(R) == size(args[1]) && ischunkedbroadcast(R, tail(args)) +ischunkedbroadcast(R, args::Tuple{<:Bool,Vararg{Any}}) = ischunkedbroadcast(R, tail(args)) +ischunkedbroadcast(R, args::Tuple{<:BroadcastedChunkableOp,Vararg{Any}}) = ischunkedbroadcast(R, args[1]) && ischunkedbroadcast(R, tail(args)) +ischunkedbroadcast(R, args::Tuple{}) = true + +liftchunks(::Tuple{}) = () +liftchunks(args::Tuple{<:BitArray,Vararg{Any}}) = (args[1].chunks, liftchunks(tail(args))...) +# Transform scalars to repeated scalars the size of a chunk +liftchunks(args::Tuple{<:Bool,Vararg{Any}}) = (ifelse(args[1], typemax(UInt64), UInt64(0)), liftchunks(tail(args))...) +ithchunk(i) = () +Base.@propagate_inbounds ithchunk(i, c::Vector{UInt64}, args...) = (c[i], ithchunk(i, args...)...) +Base.@propagate_inbounds ithchunk(i, b::UInt64, args...) = (b, ithchunk(i, args...)...) +function chunkedcopyto!(dest::BitArray, bc::Broadcasted) + isempty(dest) && return dest + f = flatten(bc) + args = liftchunks(f.args) + dc = dest.chunks + @simd for i in eachindex(dc) + @inbounds dc[i] = f.f(ithchunk(i, args...)...) + end + @inbounds dc[end] &= Base._msk_end(dest) + return dest +end + + +@noinline throwdm(axdest, axsrc) = + throw(DimensionMismatch("destination axes $axdest are not compatible with source axes $axsrc")) + +function copyto_nonleaf!(dest, bc::Broadcasted, iter, state, count) + T = eltype(dest) + while !done(iter, state) + I, state = next(iter, state) + @inbounds val = bc[I] + S = typeof(val) + if S <: T + @inbounds dest[I] = val + else + # This element type doesn't fit in dest. Allocate a new dest with wider eltype, + # copy over old values, and continue + newdest = Base.similar(dest, promote_typejoin(T, S)) + for II in Iterators.take(iter, count) + newdest[II] = dest[II] + end + newdest[I] = val + return copyto_nonleaf!(newdest, bc, iter, state, count+1) + end + count += 1 + end + return dest +end + +## Tuple methods + +@inline copy(bc::Broadcasted{Style{Tuple}}) = + tuplebroadcast(longest_tuple(nothing, bc.args), bc) +@inline tuplebroadcast(::NTuple{N,Any}, bc) where {N} = ntuple(k -> @inbounds(_broadcast_getindex(bc, k)), Val(N)) +# This is a little tricky: find the longest tuple (first arg) within the list of arguments (second arg) +# Start with nothing as a placeholder and go until we find the first tuple in the argument list +longest_tuple(::Nothing, t::Tuple{Tuple,Vararg{Any}}) = longest_tuple(t[1], tail(t)) +# Or recurse through nested broadcast expressions +longest_tuple(::Nothing, t::Tuple{Broadcasted,Vararg{Any}}) = longest_tuple(longest_tuple(nothing, t[1].args), tail(t)) +longest_tuple(::Nothing, t::Tuple) = longest_tuple(nothing, tail(t)) +# And then compare it against all other tuples we find in the argument list or nested broadcasts +longest_tuple(l::Tuple, t::Tuple{Tuple,Vararg{Any}}) = longest_tuple(_longest_tuple(l, t[1]), tail(t)) +longest_tuple(l::Tuple, t::Tuple) = longest_tuple(l, tail(t)) +longest_tuple(l::Tuple, ::Tuple{}) = l +longest_tuple(l::Tuple, t::Tuple{Broadcasted}) = longest_tuple(l, t[1].args) +longest_tuple(l::Tuple, t::Tuple{Broadcasted,Vararg{Any}}) = longest_tuple(longest_tuple(l, t[1].args), tail(t)) # Support only 1-tuples and N-tuples where there are no conflicts in N _longest_tuple(A::Tuple{Any}, B::Tuple{Any}) = A -_longest_tuple(A::NTuple{N,Any}, B::NTuple{N,Any}) where N = A -_longest_tuple(A::NTuple{N,Any}, B::Tuple{Any}) where N = A _longest_tuple(A::Tuple{Any}, B::NTuple{N,Any}) where N = B +_longest_tuple(A::NTuple{N,Any}, B::Tuple{Any}) where N = A +_longest_tuple(A::NTuple{N,Any}, B::NTuple{N,Any}) where N = A @noinline _longest_tuple(A, B) = throw(DimensionMismatch("tuples $A and $B could not be broadcast to a common size")) -tuplebroadcast_getargs(::Tuple{}, k) = () -@inline tuplebroadcast_getargs(As, k) = - (_broadcast_getindex(first(As), k), tuplebroadcast_getargs(tail(As), k)...) +## scalar-range broadcast operations ## +# DefaultArrayStyle and \ are not available at the time of range.jl +make(::DefaultArrayStyle{1}, ::typeof(-), r::OrdinalRange) = range(-first(r), step=-step(r), length=length(r)) +make(::DefaultArrayStyle{1}, ::typeof(-), r::StepRangeLen) = StepRangeLen(-r.ref, -r.step, length(r), r.offset) +make(::DefaultArrayStyle{1}, ::typeof(-), r::LinRange) = LinRange(-r.start, -r.stop, length(r)) + +make(::DefaultArrayStyle{1}, ::typeof(+), x::Real, r::AbstractUnitRange) = range(x + first(r), length=length(r)) +make(::DefaultArrayStyle{1}, ::typeof(+), r::AbstractUnitRange, x::Real) = range(first(r) + x, length=length(r)) +# For #18336 we need to prevent promotion of the step type: +make(::DefaultArrayStyle{1}, ::typeof(+), r::AbstractRange, x::Number) = range(first(r) + x, step=step(r), length=length(r)) +make(::DefaultArrayStyle{1}, ::typeof(+), x::Number, r::AbstractRange) = range(x + first(r), step=step(r), length=length(r)) +make(::DefaultArrayStyle{1}, ::typeof(+), r::StepRangeLen{T}, x::Number) where T = + StepRangeLen{typeof(T(r.ref)+x)}(r.ref + x, r.step, length(r), r.offset) +make(::DefaultArrayStyle{1}, ::typeof(+), x::Number, r::StepRangeLen{T}) where T = + StepRangeLen{typeof(x+T(r.ref))}(x + r.ref, r.step, length(r), r.offset) +make(::DefaultArrayStyle{1}, ::typeof(+), r::LinRange, x::Number) = LinRange(r.start + x, r.stop + x, length(r)) +make(::DefaultArrayStyle{1}, ::typeof(+), x::Number, r::LinRange) = LinRange(x + r.start, x + r.stop, length(r)) +make(::DefaultArrayStyle{1}, ::typeof(+), r1::AbstractRange, r2::AbstractRange) = r1 + r2 + +make(::DefaultArrayStyle{1}, ::typeof(-), r::AbstractUnitRange, x::Number) = range(first(r)-x, length=length(r)) +make(::DefaultArrayStyle{1}, ::typeof(-), r::AbstractRange, x::Number) = range(first(r)-x, step=step(r), length=length(r)) +make(::DefaultArrayStyle{1}, ::typeof(-), x::Number, r::AbstractRange) = range(x-first(r), step=-step(r), length=length(r)) +make(::DefaultArrayStyle{1}, ::typeof(-), r::StepRangeLen{T}, x::Number) where T = + StepRangeLen{typeof(T(r.ref)-x)}(r.ref - x, r.step, length(r), r.offset) +make(::DefaultArrayStyle{1}, ::typeof(-), x::Number, r::StepRangeLen{T}) where T = + StepRangeLen{typeof(x-T(r.ref))}(x - r.ref, -r.step, length(r), r.offset) +make(::DefaultArrayStyle{1}, ::typeof(-), r::LinRange, x::Number) = LinRange(r.start - x, r.stop - x, length(r)) +make(::DefaultArrayStyle{1}, ::typeof(-), x::Number, r::LinRange) = LinRange(x - r.start, x - r.stop, length(r)) +make(::DefaultArrayStyle{1}, ::typeof(-), r1::AbstractRange, r2::AbstractRange) = r1 - r2 + +make(::DefaultArrayStyle{1}, ::typeof(*), x::Number, r::AbstractRange) = range(x*first(r), step=x*step(r), length=length(r)) +make(::DefaultArrayStyle{1}, ::typeof(*), x::Number, r::StepRangeLen{T}) where {T} = + StepRangeLen{typeof(x*T(r.ref))}(x*r.ref, x*r.step, length(r), r.offset) +make(::DefaultArrayStyle{1}, ::typeof(*), x::Number, r::LinRange) = LinRange(x * r.start, x * r.stop, r.len) +# separate in case of noncommutative multiplication +make(::DefaultArrayStyle{1}, ::typeof(*), r::AbstractRange, x::Number) = range(first(r)*x, step=step(r)*x, length=length(r)) +make(::DefaultArrayStyle{1}, ::typeof(*), r::StepRangeLen{T}, x::Number) where {T} = + StepRangeLen{typeof(T(r.ref)*x)}(r.ref*x, r.step*x, length(r), r.offset) +make(::DefaultArrayStyle{1}, ::typeof(*), r::LinRange, x::Number) = LinRange(r.start * x, r.stop * x, r.len) + +make(::DefaultArrayStyle{1}, ::typeof(/), r::AbstractRange, x::Number) = range(first(r)/x, step=step(r)/x, length=length(r)) +make(::DefaultArrayStyle{1}, ::typeof(/), r::StepRangeLen{T}, x::Number) where {T} = + StepRangeLen{typeof(T(r.ref)/x)}(r.ref/x, r.step/x, length(r), r.offset) +make(::DefaultArrayStyle{1}, ::typeof(/), r::LinRange, x::Number) = LinRange(r.start / x, r.stop / x, r.len) + +make(::DefaultArrayStyle{1}, ::typeof(\), x::Number, r::AbstractRange) = range(x\first(r), step=x\step(r), length=length(r)) +make(::DefaultArrayStyle{1}, ::typeof(\), x::Number, r::StepRangeLen) = StepRangeLen(x\r.ref, x\r.step, length(r), r.offset) +make(::DefaultArrayStyle{1}, ::typeof(\), x::Number, r::LinRange) = LinRange(x \ r.start, x \ r.stop, r.len) + +make(::DefaultArrayStyle{1}, ::typeof(big), r::UnitRange) = big(r.start):big(last(r)) +make(::DefaultArrayStyle{1}, ::typeof(big), r::StepRange) = big(r.start):big(r.step):big(last(r)) +make(::DefaultArrayStyle{1}, ::typeof(big), r::StepRangeLen) = StepRangeLen(big(r.ref), big(r.step), length(r), r.offset) +make(::DefaultArrayStyle{1}, ::typeof(big), r::LinRange) = LinRange(big(r.start), big(r.stop), length(r)) """ @@ -739,16 +1035,14 @@ julia> broadcast_getindex(A, [1 2 1; 1 2 2], [1, 2]) ``` """ broadcast_getindex(src::AbstractArray, I::AbstractArray...) = - broadcast_getindex!(Base.similar(Array{eltype(src)}, combine_indices(I...)), - src, - I...) + broadcast_getindex!(Base.similar(Array{eltype(src)}, combine_axes(I...)), src, I...) @generated function broadcast_getindex!(dest::AbstractArray, src::AbstractArray, I::AbstractArray...) N = length(I) Isplat = Expr[:(I[$d]) for d = 1:N] quote @nexprs $N d->(I_d = I[d]) - check_broadcast_indices(Base.axes(dest), $(Isplat...)) # unnecessary if this function is never called directly + check_broadcast_axes(Base.axes(dest), $(Isplat...)) # unnecessary if this function is never called directly checkbounds(src, $(Isplat...)) @nexprs $N d->(@nexprs $N k->(Ibcast_d_k = Base.axes(I_k, d) == OneTo(1))) @nloops $N i dest d->(@nexprs $N k->(j_d_k = Ibcast_d_k ? 1 : i_d)) begin @@ -779,7 +1073,7 @@ See [`broadcast_getindex`](@ref) for examples of the treatment of `inds`. quote @nexprs $N d->(I_d = I[d]) checkbounds(A, $(Isplat...)) - shape = combine_indices($(Isplat...)) + shape = combine_axes($(Isplat...)) @nextract $N shape d->(length(shape) < d ? OneTo(1) : shape[d]) @nexprs $N d->(@nexprs $N k->(Ibcast_d_k = Base.axes(I_k, d) == 1:1)) if !isa(x, AbstractArray) @@ -892,4 +1186,26 @@ macro __dot__(x) esc(__dot__(x)) end +@inline make_kwsyntax(f, args...; kwargs...) = make((args...)->f(args...; kwargs...), args...) +@inline function make(f, args...) + args′ = map(broadcastable, args) + make(combine_styles(args′...), f, args′...) +end +# Due to the current Type{T}/DataType specialization heuristics within Tuples, +# the totally generic varargs make(f, args...) method above loses Type{T}s in +# mapping broadcastable across the args. These additional methods with explicit +# arguments ensure we preserve Type{T}s in the first or second argument position. +@inline function make(f, arg1, args...) + arg1′ = broadcastable(arg1) + args′ = map(broadcastable, args) + make(combine_styles(arg1′, args′...), f, arg1′, args′...) +end +@inline function make(f, arg1, arg2, args...) + arg1′ = broadcastable(arg1) + arg2′ = broadcastable(arg2) + args′ = map(broadcastable, args) + make(combine_styles(arg1′, arg2′, args′...), f, arg1′, arg2′, args′...) +end +@inline make(::S, f, args...) where S<:BroadcastStyle = Broadcasted{S}(f, args) + end # module diff --git a/base/compiler/ssair/inlining2.jl b/base/compiler/ssair/inlining2.jl index 02c2fb435048a..085cb2733fb83 100644 --- a/base/compiler/ssair/inlining2.jl +++ b/base/compiler/ssair/inlining2.jl @@ -56,7 +56,7 @@ function batch_inline!(todo::Vector{InliningTodo}, ir::IRCode, linetable::Vector if first_bb != block new_range = first_bb+1:block - bb_rename[new_range] = (1:length(new_range)) .+ length(new_cfg_blocks) + bb_rename[new_range] = (1+length(new_cfg_blocks)):(length(new_range)+length(new_cfg_blocks)) append!(new_cfg_blocks, map(copy, ir.cfg.blocks[new_range])) push!(merged_orig_blocks, last(new_range)) end @@ -79,12 +79,12 @@ function batch_inline!(todo::Vector{InliningTodo}, ir::IRCode, linetable::Vector orig_succs = copy(new_cfg_blocks[end].succs) empty!(new_cfg_blocks[end].succs) if need_split_before - bb_rename_range = (1:length(inlinee_cfg.blocks)) .+ length(new_cfg_blocks) + bb_rename_range = (1+length(new_cfg_blocks)):(length(inlinee_cfg.blocks)+length(new_cfg_blocks)) push!(new_cfg_blocks[end].succs, length(new_cfg_blocks)+1) append!(new_cfg_blocks, inlinee_cfg.blocks) else # Merge the last block that was already there with the first block we're adding - bb_rename_range = (1:length(inlinee_cfg.blocks)) .+ (length(new_cfg_blocks) - 1) + bb_rename_range = length(new_cfg_blocks):(length(inlinee_cfg.blocks)+length(new_cfg_blocks)-1) append!(new_cfg_blocks[end].succs, inlinee_cfg.blocks[1].succs) append!(new_cfg_blocks, inlinee_cfg.blocks[2:end]) end @@ -130,7 +130,7 @@ function batch_inline!(todo::Vector{InliningTodo}, ir::IRCode, linetable::Vector end end new_range = (first_bb + 1):length(ir.cfg.blocks) - bb_rename[new_range] = (1:length(new_range)) .+ length(new_cfg_blocks) + bb_rename[new_range] = (1+length(new_cfg_blocks)):(length(new_range)+length(new_cfg_blocks)) append!(new_cfg_blocks, ir.cfg.blocks[new_range]) # Rename edges original bbs diff --git a/base/compiler/ssair/slot2ssa.jl b/base/compiler/ssair/slot2ssa.jl index e18faa5ed9a2d..72cb54b13647e 100644 --- a/base/compiler/ssair/slot2ssa.jl +++ b/base/compiler/ssair/slot2ssa.jl @@ -371,12 +371,12 @@ function domsort_ssa!(ir::IRCode, domtree::DomTree) crit_edge_breaks_fixup = Tuple{Int, Int}[] for (new_bb, bb) in pairs(result_order) if bb == 0 - new_bbs[new_bb] = BasicBlock((1:1) .+ bb_start_off, [new_bb-1], [result_stmts[bb_start_off].dest]) + new_bbs[new_bb] = BasicBlock((bb_start_off+1):(bb_start_off+1), [new_bb-1], [result_stmts[bb_start_off].dest]) bb_start_off += 1 continue end old_inst_range = ir.cfg.blocks[bb].stmts - inst_range = (1:length(old_inst_range)) .+ bb_start_off + inst_range = (bb_start_off+1):(bb_start_off+length(old_inst_range)) inst_rename[old_inst_range] = Any[SSAValue(x) for x in inst_range] for (nidx, idx) in zip(inst_range, old_inst_range) stmt = ir.stmts[idx] diff --git a/base/deprecated.jl b/base/deprecated.jl index ec73aa06b4950..8efe89cbe319c 100644 --- a/base/deprecated.jl +++ b/base/deprecated.jl @@ -1115,6 +1115,10 @@ end @deprecate indices(a) axes(a) @deprecate indices(a, d) axes(a, d) +# And similar _indices names in Broadcast +@eval Broadcast Base.@deprecate_binding broadcast_indices broadcast_axes true +@eval Broadcast Base.@deprecate_binding check_broadcast_indices check_broadcast_axes false + # PR #25046 export reload, workspace reload(name::AbstractString) = error("`reload($(repr(name)))` is discontinued, consider Revise.jl for an alternative workflow.") diff --git a/base/float.jl b/base/float.jl index 6b468f238279d..db86677c6a8a0 100644 --- a/base/float.jl +++ b/base/float.jl @@ -875,13 +875,3 @@ float(r::StepRangeLen{T}) where {T} = function float(r::LinRange) LinRange(float(r.start), float(r.stop), length(r)) end - -# big, broadcast over arrays -# TODO: do the definitions below primarily pertaining to integers belong in float.jl? -function big end # no prior definitions of big in sysimg.jl, necessitating this -broadcast(::typeof(big), r::UnitRange) = big(r.start):big(last(r)) -broadcast(::typeof(big), r::StepRange) = big(r.start):big(r.step):big(last(r)) -broadcast(::typeof(big), r::StepRangeLen) = StepRangeLen(big(r.ref), big(r.step), length(r), r.offset) -function broadcast(::typeof(big), r::LinRange) - LinRange(big(r.start), big(r.stop), length(r)) -end diff --git a/base/range.jl b/base/range.jl index 950e44e300a55..91a8060ced324 100644 --- a/base/range.jl +++ b/base/range.jl @@ -719,67 +719,6 @@ end StepRangeLen{T,R,S}(-r.ref, -r.step, length(r), r.offset) -(r::LinRange) = LinRange(-r.start, -r.stop, length(r)) -*(x::Number, r::AbstractRange) = range(x*first(r), step=x*step(r), length=length(r)) -*(x::Number, r::StepRangeLen{T}) where {T} = - StepRangeLen{typeof(x*T(r.ref))}(x*r.ref, x*r.step, length(r), r.offset) -*(x::Number, r::LinRange) = LinRange(x * r.start, x * r.stop, r.len) -# separate in case of noncommutative multiplication -*(r::AbstractRange, x::Number) = range(first(r)*x, step=step(r)*x, length=length(r)) -*(r::StepRangeLen{T}, x::Number) where {T} = - StepRangeLen{typeof(T(r.ref)*x)}(r.ref*x, r.step*x, length(r), r.offset) -*(r::LinRange, x::Number) = LinRange(r.start * x, r.stop * x, r.len) - -/(r::AbstractRange, x::Number) = range(first(r)/x, step=step(r)/x, length=length(r)) -/(r::StepRangeLen{T}, x::Number) where {T} = - StepRangeLen{typeof(T(r.ref)/x)}(r.ref/x, r.step/x, length(r), r.offset) -/(r::LinRange, x::Number) = LinRange(r.start / x, r.stop / x, r.len) -# also, separate in case of noncommutative multiplication (division) -\(x::Number, r::AbstractRange) = range(x\first(r), step=x\step(r), length=x\length(r)) -\(x::Number, r::StepRangeLen) = StepRangeLen(x\r.ref, x\r.step, length(r), r.offset) -\(x::Number, r::LinRange) = LinRange(x \ r.start, x \ r.stop, r.len) - -## scalar-range broadcast operations ## - -broadcast(::typeof(-), r::OrdinalRange) = range(-first(r), step=-step(r), length=length(r)) -broadcast(::typeof(-), r::StepRangeLen) = StepRangeLen(-r.ref, -r.step, length(r), r.offset) -broadcast(::typeof(-), r::LinRange) = LinRange(-r.start, -r.stop, length(r)) - -broadcast(::typeof(+), x::Real, r::AbstractUnitRange) = range(x + first(r), length=length(r)) -# For #18336 we need to prevent promotion of the step type: -broadcast(::typeof(+), x::Number, r::AbstractUnitRange) = range(x + first(r), step=step(r), length=length(r)) -broadcast(::typeof(+), x::Number, r::AbstractRange) = (x+first(r)):step(r):(x+last(r)) -function broadcast(::typeof(+), x::Number, r::StepRangeLen{T}) where T - newref = x + r.ref - StepRangeLen{typeof(T(r.ref) + x)}(newref, r.step, length(r), r.offset) -end -function broadcast(::typeof(+), x::Number, r::LinRange) - LinRange(x + r.start, x + r.stop, r.len) -end -broadcast(::typeof(+), r::AbstractRange, x::Number) = broadcast(+, x, r) # assumes addition is commutative - -broadcast(::typeof(-), x::Number, r::AbstractRange) = (x-first(r)):-step(r):(x-last(r)) -broadcast(::typeof(-), x::Number, r::StepRangeLen) = broadcast(+, x, -r) -function broadcast(::typeof(-), x::Number, r::LinRange) - LinRange(x - r.start, x - r.stop, r.len) -end - -broadcast(::typeof(-), r::AbstractRange, x::Number) = broadcast(+, -x, r) # assumes addition is commutative - -broadcast(::typeof(*), x::Number, r::AbstractRange) = range(x*first(r), step=x*step(r), length=length(r)) -broadcast(::typeof(*), x::Number, r::StepRangeLen) = StepRangeLen(x*r.ref, x*r.step, length(r), r.offset) -broadcast(::typeof(*), x::Number, r::LinRange) = LinRange(x * r.start, x * r.stop, r.len) -# separate in case of noncommutative multiplication -broadcast(::typeof(*), r::AbstractRange, x::Number) = range(first(r)*x, step=step(r)*x, length=length(r)) -broadcast(::typeof(*), r::StepRangeLen, x::Number) = StepRangeLen(r.ref*x, r.step*x, length(r), r.offset) -broadcast(::typeof(*), r::LinRange, x::Number) = LinRange(r.start * x, r.stop * x, r.len) - -broadcast(::typeof(/), r::AbstractRange, x::Number) = range(first(r)/x, step=step(r)/x, length=length(r)) -broadcast(::typeof(/), r::StepRangeLen, x::Number) = StepRangeLen(r.ref/x, r.step/x, length(r), r.offset) -broadcast(::typeof(/), r::LinRange, x::Number) = LinRange(r.start / x, r.stop / x, r.len) -# also, separate in case of noncommutative multiplication (division) -broadcast(::typeof(\), x::Number, r::AbstractRange) = range(x\first(r), step=x\step(r), length=x\length(r)) -broadcast(::typeof(\), x::Number, r::StepRangeLen) = StepRangeLen(x\r.ref, x\r.step, length(r), r.offset) -broadcast(::typeof(\), x::Number, r::LinRange) = LinRange(x \ r.start, x \ r.stop, r.len) # promote eltype if at least one container wouldn't change, otherwise join container types. el_same(::Type{T}, a::Type{<:AbstractArray{T,n}}, b::Type{<:AbstractArray{T,n}}) where {T,n} = a @@ -851,8 +790,6 @@ promote_rule(a::Type{LinRange{T}}, ::Type{OR}) where {T,OR<:OrdinalRange} = promote_rule(::Type{LinRange{L}}, b::Type{StepRangeLen{T,R,S}}) where {L,T,R,S} = promote_rule(StepRangeLen{L,L,L}, b) -# +/- of ranges is defined in operators.jl (to be able to use @eval etc.) - ## concatenation ## function vcat(rs::AbstractRange{T}...) where T @@ -960,6 +897,3 @@ function +(r1::StepRangeLen{T,S}, r2::StepRangeLen{T,S}) where {T,S} end -(r1::StepRangeLen, r2::StepRangeLen) = +(r1, -r2) - -broadcast(::typeof(+), r1::AbstractRange, r2::AbstractRange) = r1 + r2 -broadcast(::typeof(-), r1::AbstractRange, r2::AbstractRange) = r1 - r2 diff --git a/base/reducedim.jl b/base/reducedim.jl index a556fbe3667aa..2b0dc06321cad 100644 --- a/base/reducedim.jl +++ b/base/reducedim.jl @@ -218,7 +218,7 @@ function _mapreducedim!(f, op, R::AbstractArray, A::AbstractArray) return R end indsAt, indsRt = safe_tail(axes(A)), safe_tail(axes(R)) # handle d=1 manually - keep, Idefault = Broadcast.shapeindexer(indsAt, indsRt) + keep, Idefault = Broadcast.shapeindexer(indsRt) if reducedim1(R, A) # keep the accumulator as a local variable when reducing along the first dimension i1 = first(indices1(R)) @@ -667,7 +667,7 @@ function findminmax!(f, Rval, Rind, A::AbstractArray{T,N}) where {T,N} # If we're reducing along dimension 1, for efficiency we can make use of a temporary. # Otherwise, keep the result in Rval/Rind so that we traverse A in storage order. indsAt, indsRt = safe_tail(axes(A)), safe_tail(axes(Rval)) - keep, Idefault = Broadcast.shapeindexer(indsAt, indsRt) + keep, Idefault = Broadcast.shapeindexer(indsRt) ks = keys(A) k, kss = next(ks, start(ks)) zi = zero(eltype(ks)) diff --git a/base/sort.jl b/base/sort.jl index 586ec4dedbf53..adf90a5162b78 100644 --- a/base/sort.jl +++ b/base/sort.jl @@ -95,9 +95,12 @@ issorted(itr; function partialsort!(v::AbstractVector, k::Union{Int,OrdinalRange}, o::Ordering) inds = axes(v, 1) sort!(v, first(inds), last(inds), PartialQuickSort(k), o) - @views v[k] + maybeview(v, k) end +maybeview(v, k) = view(v, k) +maybeview(v, k::Integer) = v[k] + """ partialsort!(v, k; by=, lt=, rev=false) @@ -716,7 +719,7 @@ function partialsortperm!(ix::AbstractVector{<:Integer}, v::AbstractVector, # do partial quicksort sort!(ix, PartialQuickSort(k), Perm(ord(lt, by, rev, order), v)) - @views ix[k] + maybeview(ix, k) end ## sortperm: the permutation to sort an array ## diff --git a/base/statistics.jl b/base/statistics.jl index 3b0bbb5b9f9ac..350e64639a034 100644 --- a/base/statistics.jl +++ b/base/statistics.jl @@ -145,7 +145,7 @@ function centralize_sumabs2!(R::AbstractArray{S}, A::AbstractArray, means::Abstr return R end indsAt, indsRt = safe_tail(axes(A)), safe_tail(axes(R)) # handle d=1 manually - keep, Idefault = Broadcast.shapeindexer(indsAt, indsRt) + keep, Idefault = Broadcast.shapeindexer(indsRt) if reducedim1(R, A) i1 = first(indices1(R)) @inbounds for IA in CartesianIndices(indsAt) diff --git a/doc/src/base/arrays.md b/doc/src/base/arrays.md index 3357292838d57..2b1a8d7236e56 100644 --- a/doc/src/base/arrays.md +++ b/doc/src/base/arrays.md @@ -69,7 +69,7 @@ For specializing broadcast on custom types, see ```@docs Base.BroadcastStyle Base.broadcast_similar -Base.broadcast_indices +Base.broadcast_axes Base.Broadcast.AbstractArrayStyle Base.Broadcast.ArrayStyle Base.Broadcast.DefaultArrayStyle diff --git a/doc/src/manual/interfaces.md b/doc/src/manual/interfaces.md index d13bd756da854..e237818334ed7 100644 --- a/doc/src/manual/interfaces.md +++ b/doc/src/manual/interfaces.md @@ -435,22 +435,22 @@ V = view(A, [1,2,4], :) # is not strided, as the spacing between rows is not f -## [Broadcasting](@id man-interfaces-broadcasting) +## [Customizing broadcasting](@id man-interfaces-broadcasting) | Methods to implement | Brief description | |:-------------------- |:----------------- | | `Base.BroadcastStyle(::Type{SrcType}) = SrcStyle()` | Broadcasting behavior of `SrcType` | -| `Base.broadcast_similar(f, ::DestStyle, ::Type{ElType}, inds, As...)` | Allocation of output container | +| `Base.broadcast_similar(::DestStyle, ::Type{ElType}, inds, bc)` | Allocation of output container | | **Optional methods** | | | | `Base.BroadcastStyle(::Style1, ::Style2) = Style12()` | Precedence rules for mixing styles | -| `Base.broadcast_indices(::StyleA, A)` | Declaration of the indices of `A` for broadcasting purposes (defaults to [`axes(A)`](@ref)) | +| `Base.broadcast_axes(::StyleA, A)` | Declaration of the indices of `A` for broadcasting purposes (defaults to [`axes(A)`](@ref)) | | `Base.broadcastable(x)` | Convert `x` to an object that has `axes` and supports indexing | | **Bypassing default machinery** | | -| `broadcast(f, As...)` | Complete bypass of broadcasting machinery | -| `broadcast(f, ::DestStyle, ::Nothing, ::Nothing, As...)` | Bypass after container type is computed | -| `broadcast(f, ::DestStyle, ::Type{ElType}, inds::Tuple, As...)` | Bypass after container type, eltype, and indices are computed | -| `broadcast!(f, dest::DestType, ::Nothing, As...)` | Bypass in-place broadcast, specialization on destination type | -| `broadcast!(f, dest, ::BroadcastStyle, As...)` | Bypass in-place broadcast, specialization on `BroadcastStyle` | +| `Base.copy(bc::Broadcasted{DestStyle})` | Custom implementation of `broadcast` | +| `Base.copyto!(dest, bc::Broadcasted{DestStyle})` | Custom implementation of `broadcast!`, specializing on `DestStyle` | +| `Base.copyto!(dest::DestType, bc::Broadcasted{Nothing})` | Custom implementation of `broadcast!`, specializing on `DestType` | +| `Base.Broadcast.make(f, args...)` | Override the default lazy behavior within a fused expression | +| `Base.Broadcast.instantiate(bc::Broadcasted{DestStyle})` | Override the computation of the wrapper's axes and indexers | [Broadcasting](@ref) is triggered by an explicit call to `broadcast` or `broadcast!`, or implicitly by "dot" operations like `A .+ b` or `f.(x, y)`. Any object that has [`axes`](@ref) and supports @@ -463,16 +463,16 @@ in an `Array`. This basic framework is extensible in three major ways: Not all types support `axes` and indexing, but many are convenient to allow in broadcast. The [`Base.broadcastable`](@ref) function is called on each argument to broadcast, allowing -it to return something different that supports `axes` and indexing if it does not. By +it to return something different that supports `axes` and indexing. By default, this is the identity function for all `AbstractArray`s and `Number`s — they already support `axes` and indexing. For a handful of other types (including but not limited to types themselves, functions, special singletons like `missing` and `nothing`, and dates), `Base.broadcastable` returns the argument wrapped in a `Ref` to act as a 0-dimensional "scalar" for the purposes of broadcasting. Custom types can similarly specialize `Base.broadcastable` to define their shape, but they should follow the convention that -`collect(Base.broadcastable(x)) == collect(x)`. A notable exception are `AbstractString`s; -they are special-cased to behave as scalars for the purposes of broadcast even though they -are iterable collections of their characters. +`collect(Base.broadcastable(x)) == collect(x)`. A notable exception is `AbstractString`; +strings are special-cased to behave as scalars for the purposes of broadcast even though +they are iterable collections of their characters. The next two steps (selecting the output array and implementation) are dependent upon determining a single answer for a given set of arguments. Broadcast must take all the varied @@ -483,12 +483,11 @@ styles into a single answer — the "destination style". ### Broadcast Styles -`Base.BroadcastStyle` is the abstract type from which all styles are -derived. When used as a function it has two possible forms, -unary (single-argument) and binary. -The unary variant states that you intend to -implement specific broadcasting behavior and/or output type, -and do not wish to rely on the default fallback ([`Broadcast.DefaultArrayStyle`](@ref)). +`Base.BroadcastStyle` is the abstract type from which all broadcast styles are derived. When used as a +function it has two possible forms, unary (single-argument) and binary. The unary variant states +that you intend to implement specific broadcasting behavior and/or output type, and do not wish to +rely on the default fallback [`Broadcast.DefaultArrayStyle`](@ref). + To override these defaults, you can define a custom `BroadcastStyle` for your object: ```julia @@ -507,27 +506,30 @@ leverage one of the general broadcast wrappers: When your broadcast operation involves several arguments, individual argument styles get combined to determine a single `DestStyle` that controls the type of the output container. -For more detail, see [below](@ref writing-binary-broadcasting-rules). +For more details, see [below](@ref writing-binary-broadcasting-rules). ### Selecting an appropriate output array -The actual allocation of the result array is handled by `Base.broadcast_similar`: +The broadcast style is computed for every broadcasting operation to allow for +dispatch and specialization. The actual allocation of the result array is +handled by `Base.broadcast_similar`, using this style as its first argument. ```julia -Base.broadcast_similar(f, ::DestStyle, ::Type{ElType}, inds, As...) +Base.broadcast_similar(::DestStyle, ::Type{ElType}, inds, bc) ``` -`f` is the operation being performed and `DestStyle` signals the final result from -combining the input styles. -`As...` is the list of input objects. You may not need to use `f` or `As...` -unless they help you build the appropriate object; the fallback definition is +The fallback definition is ```julia -broadcast_similar(f, ::DefaultArrayStyle{N}, ::Type{ElType}, inds::Indices{N}, As...) where {N,ElType} = +broadcast_similar(::DefaultArrayStyle{N}, ::Type{ElType}, inds::Indices{N}, bc) where {N,ElType} = similar(Array{ElType}, inds) ``` -However, if needed you can specialize on any or all of these arguments. +However, if needed you can specialize on any or all of these arguments. The final argument +`bc` is a lazy representation of a (potentially fused) broadcast operation, a `Broadcasted` +object. For these purposes, the most important fields of the wrapper are +`f` and `args`, describing the function and argument list, respectively. Note that the argument +list can — and often does — include other nested `Broadcasted` wrappers. For a complete example, let's say you have created a type, `ArrayAndChar`, that stores an array and a single character: @@ -553,20 +555,21 @@ Base.BroadcastStyle(::Type{<:ArrayAndChar}) = Broadcast.ArrayStyle{ArrayAndChar} ``` -This forces us to also define a `broadcast_similar` method: -```jldoctest ArrayAndChar; filter = r"(^find_aac \(generic function with 2 methods\)$|^$)" -function Base.broadcast_similar(f, ::Broadcast.ArrayStyle{ArrayAndChar}, ::Type{ElType}, inds, As...) where ElType +This means we must also define a corresponding `broadcast_similar` method: +```jldoctest +function Base.broadcast_similar(::Broadcast.ArrayStyle{ArrayAndChar}, ::Type{ElType}, inds, bc) where ElType # Scan the inputs for the ArrayAndChar: - A = find_aac(As...) + A = find_aac(bc) # Use the char field of A to create the output ArrayAndChar(similar(Array{ElType}, inds), A.char) end -"`A = find_aac(As...)` returns the first ArrayAndChar among the arguments." -find_aac(A::ArrayAndChar, B...) = A -find_aac(A, B...) = find_aac(B...); -# output - +"`A = find_aac(As)` returns the first ArrayAndChar among the arguments." +find_aac(bc::Base.Broadcast.Broadcasted) = find_aac(bc.args) +find_aac(args::Tuple) = find_aac(find_aac(args[1]), Base.tail(args)) +find_aac(x) = x +find_aac(a::ArrayAndChar, rest) = a +find_aac(::Any, rest) = find_aac(rest) ``` From these definitions, one obtains the following behavior: @@ -589,58 +592,86 @@ julia> a .+ [5,10] ### [Extending broadcast with custom implementations](@id extending-in-place-broadcast) -Finally, it's worth noting that sometimes it's easier simply to bypass the machinery for -computing result types and container sizes, and just do everything manually. For example, -you can convert a `UnitRange{Int}` `r` to a `UnitRange{BigInt}` with `big.(r)`; the definition -of this method is approximately +In general, a broadcast operation is represented by a lazy `Broadcasted` container that holds onto +the function to be applied alongside its arguments. Those arguments may themselves be more nested +`Broadcasted` containers, forming a large expression tree to be evaluated. A nested tree of +`Broadcasted` containers is directly constructed by the implicit dot syntax; `5 .+ 2.*x` is +transiently represented by `Broadcasted(+, 5, Broadcasted(*, 2, x))`, for example. This is +invisible to users as it is immediately realized through a call to `copy`, but it is this container +that provides the basis for broadcast's extensibility for authors of custom types. The built-in +broadcast machinery will then determine the result type and size based upon the arguments, allocate +it, and then finally copy the realization of the `Broadcasted` object into it with a default +`copyto!(::AbstractArray, ::Broadcasted)` method. The built-in fallback `broadcast` and +`broadcast!` methods similarly construct a transient `Broadcasted` representation of the operation +so they can follow the same codepath. This allows custom array implementations to +provide their own `copyto!` specialization to customize and +optimize broadcasting. This is again determined by the computed broadcast style. This is such +an important part of the operation that it is stored as the first type parameter of the +`Broadcasted` type, allowing for dispatch and specialization. + +For some types, the machinery to "fuse" operations across nested levels of broadcasting +is not available or could be done more efficiently incrementally. In such cases, you may +need or want to evaluate `x .* (x .+ 1)` as if it had been +written `broadcast(*, x, broadcast(+, x, 1))`, where the inner operation is evaluated before +tackling the outer operation. This sort of eager operation is directly supported by a bit +of indirection; instead of directly constructing `Broadcasted` objects, Julia lowers the +fused expression `x .* (x .+ 1)` to `Broadcast.make(*, x, Broadcast.make(+, x, 1))`. Now, +by default, `make` just calls the `Broadcasted` constructor to create the lazy representation +of the fused expression tree, but you can choose to override it for a particular combination +of function and arguments. + +As an example, the builtin `AbstractRange` objects use this machinery to optimize pieces +of broadcasted expressions that can be eagerly evaluated purely in terms of the start, +step, and length (or stop) instead of computing every single element. Just like all the +other machinery, `make` also computes and exposes the combined broadcast style of its +arguments, so instead of specializing on `make(f, args...)`, you can specialize on +`make(::DestStyle, f, args...)` for any combination of style, function, and arguments. + +For example, the following definition supports the negation of ranges: ```julia -Broadcast.broadcast(::typeof(big), r::UnitRange) = big(first(r)):big(last(r)) +make(::DefaultArrayStyle{1}, ::typeof(-), r::OrdinalRange) = range(-first(r), step=-step(r), length=length(r)) ``` -This exploits Julia's ability to dispatch on a particular function type. (This kind of -explicit definition can indeed be necessary if the output container does not support `setindex!`.) -You can optionally choose to implement the actual broadcasting yourself, but allow -the internal machinery to compute the container type, element type, and indices by specializing - -```julia -Broadcast.broadcast(::typeof(somefunction), ::MyStyle, ::Type{ElType}, inds, As...) -``` +### [Extending in-place broadcasting](@id extending-in-place-broadcast) -Extending `broadcast!` (in-place broadcast) should be done with care, as it is easy to introduce -ambiguities between packages. To avoid these ambiguities, we adhere to the following conventions. - -First, if you want to specialize on the destination type, say `DestType`, then you should -define a method with the following signature: +In-place broadcasting can be supported by defining the appropriate `copyto!(dest, bc::Broadcasted)` +method. Because you might want to specialize either on `dest` or the specific subtype of `bc`, +to avoid ambiguities between packages we recommend the following convention. +If you wish to specialize on a particular style `DestStyle`, define a method for ```julia -broadcast!(f, dest::DestType, ::Nothing, As...) +copyto!(dest, bc::Broadcasted{DestStyle}) ``` +Optionally, with this form you can also specialize on the type of `dest`. -Note that no bounds should be placed on the types of `f` and `As...`. - -Second, if specialized `broadcast!` behavior is desired depending on the input types, -you should write [binary broadcasting rules](@ref writing-binary-broadcasting-rules) to -determine a custom `BroadcastStyle` given the input types, say `MyBroadcastStyle`, and you should define a method with the following -signature: +If instead you want to specialize on the destination type `DestType` without specializing +on `DestStyle`, then you should define a method with the following signature: ```julia -broadcast!(f, dest, ::MyBroadcastStyle, As...) +copyto!(dest::DestType, bc::Broadcasted{Nothing}) ``` -Note the lack of bounds on `f`, `dest`, and `As...`. +This leverages a fallback implementation of `copyto!` that converts the wrapper into a +`Broadcasted{Nothing}`. Consequently, specializing on `DestType` has lower precedence than +methods that specialize on `DestStyle`. -Third, simultaneously specializing on both the type of `dest` and the `BroadcastStyle` is fine. In this case, -it is also allowed to specialize on the types of the source arguments (`As...`). For example, these method signatures are OK: +Similarly, you can completely override out-of-place broadcasting with a `copy(::Broadcasted)` +method. -```julia -broadcast!(f, dest::DestType, ::MyBroadcastStyle, As...) -broadcast!(f, dest::DestType, ::MyBroadcastStyle, As::AbstractArray...) -broadcast!(f, dest::DestType, ::Broadcast.DefaultArrayStyle{0}, As::Number...) -``` +#### Working with `Broadcasted` objects + +In order to implement such a `copy` or `copyto!`, method, of course, you must +work with the `Broadcasted` wrapper to compute each element. There are two main +ways of doing so: +* `Broadcast.flatten` recomputes the potentially nested operation into a single + function and flat list of arguments. You are responsible for implementing the + broadcasting shape rules yourself, but this may be helpful in limited situations. +* Iterating over the `CartesianIndices` of the `axes(::Broadcasted)` and using + indexing with the resulting `CartesianIndex` object to compute the result. -#### [Writing binary broadcasting rules](@id writing-binary-broadcasting-rules) +### [Writing binary broadcasting rules](@id writing-binary-broadcasting-rules) The precedence rules are defined by binary `BroadcastStyle` calls: diff --git a/src/julia-syntax.scm b/src/julia-syntax.scm index d911737cb088c..ea7956bb1f369 100644 --- a/src/julia-syntax.scm +++ b/src/julia-syntax.scm @@ -1671,53 +1671,11 @@ `(block ,@stmts ,nuref)) expr)) -; fuse nested calls to expr == f.(args...) into a single broadcast call, +; lazily fuse nested calls to expr == f.(args...) into a single broadcast call, ; or a broadcast! call if lhs is non-null. (define (expand-fuse-broadcast lhs rhs) (define (fuse? e) (and (pair? e) (eq? (car e) 'fuse))) - (define (anyfuse? exprs) - (if (null? exprs) #f (if (fuse? (car exprs)) #t (anyfuse? (cdr exprs))))) - (define (to-lambda f args kwargs) ; convert f to anonymous function with hygienic tuple args - (define (genarg arg) (if (vararg? arg) (list '... (gensy)) (gensy))) - ; (To do: optimize the case where f is already an anonymous function, in which - ; case we only need to hygienicize the arguments? But it is quite tricky - ; to fully handle splatted args, typed args, keywords, etcetera. And probably - ; the extra function call is harmless because it will get inlined anyway.) - (let ((genargs (map genarg args))) ; hygienic formal parameters - (if (null? kwargs) - `(-> ,(cons 'tuple genargs) (call ,f ,@genargs)) ; no keyword args - `(-> ,(cons 'tuple genargs) (call ,f (parameters ,@kwargs) ,@genargs))))) - (define (from-lambda f) ; convert (-> (tuple args...) (call func args...)) back to func - (if (and (pair? f) (eq? (car f) '->) (pair? (cadr f)) (eq? (caadr f) 'tuple) - (pair? (caddr f)) (eq? (caaddr f) 'call) (equal? (cdadr f) (cdr (cdaddr f)))) - (car (cdaddr f)) - f)) - (define (fuse-args oldargs) ; replace (fuse f args) with args in oldargs list - (define (fargs newargs oldargs) - (if (null? oldargs) - newargs - (fargs (if (fuse? (car oldargs)) - (append (reverse (caddar oldargs)) newargs) - (cons (car oldargs) newargs)) - (cdr oldargs)))) - (reverse (fargs '() oldargs))) - (define (fuse-funcs f args) ; for (fuse g a) in args, merge/inline g into f - ; any argument A of f that is (fuse g a) gets replaced by let A=(body of g): - (define (fuse-lets fargs args lets) - (if (null? args) - lets - (if (fuse? (car args)) - (fuse-lets (cdr fargs) (cdr args) (cons (list '= (car fargs) (caddr (cadar args))) lets)) - (fuse-lets (cdr fargs) (cdr args) lets)))) - (let ((fargs (cdadr f)) - (fbody (caddr f))) - `(-> - (tuple ,@(fuse-args (map (lambda (oldarg arg) (if (fuse? arg) - `(fuse _ ,(cdadr (cadr arg))) - oldarg)) - fargs args))) - (let (block ,@(reverse (fuse-lets fargs args '()))) ,fbody)))) - (define (dot-to-fuse e) ; convert e == (. f (tuple args)) to (fuse f args) + (define (dot-to-fuse e (top #f)) ; convert e == (. f (tuple args)) to (fuse f args) (define (make-fuse f args) ; check for nested (fuse f args) exprs and combine (define (split-kwargs args) ; return (cons keyword-args positional-args) extracted from args (define (sk args kwargs pargs) @@ -1729,78 +1687,43 @@ (if (has-parameters? args) (sk (reverse (cdr args)) (cdar args) '()) (sk (reverse args) '() '()))) - (let* ((kws.args (split-kwargs args)) - (kws (car kws.args)) - (args (cdr kws.args)) ; fusing occurs on positional args only - (args_ (map dot-to-fuse args))) - (if (anyfuse? args_) - `(fuse ,(fuse-funcs (to-lambda f args kws) args_) ,(fuse-args args_)) - `(fuse ,(to-lambda f args kws) ,args_)))) + (let* ((kws+args (split-kwargs args)) ; fusing occurs on positional args only + (kws (car kws+args)) + (kws (if (null? kws) kws (list (cons 'parameters kws)))) + (args (map dot-to-fuse (cdr kws+args))) + (make `(call (|.| (top Broadcast) ,(if (null? kws) ''make ''make_kwsyntax)) ,@kws ,f ,@args))) + (if top (cons 'fuse make) make))) (if (and (pair? e) (eq? (car e) '|.|)) (let ((f (cadr e)) (x (caddr e))) (cond ((or (atom? x) (eq? (car x) 'quote) (eq? (car x) 'inert) (eq? (car x) '$)) `(call (top getproperty) ,f ,x)) ((eq? (car x) 'tuple) - (make-fuse f (cdr x))) + (if (and (eq? f '^) (length= x 3) (integer? (caddr x))) + (make-fuse (expand-forms '(top literal_pow)) + (list '^ (cadr x) (expand-forms `(call (call (core apply_type) (top Val) ,(caddr x)))))) + (make-fuse f (cdr x)))) (else (error (string "invalid syntax \"" (deparse e) "\""))))) (if (and (pair? e) (eq? (car e) 'call) (dotop? (cadr e))) - (make-fuse (undotop (cadr e)) (cddr e)) + (let ((f (undotop (cadr e))) (x (cddr e))) + (if (and (eq? f '^) (length= x 2) (integer? (cadr x))) + (make-fuse (expand-forms '(top literal_pow)) + (list '^ (car x) (expand-forms `(call (call (core apply_type) (top Val) ,(cadr x)))))) + (make-fuse f x))) e))) - ; given e == (fuse lambda args), compress the argument list by removing (pure) - ; duplicates in args, inlining literals, and moving any varargs to the end: - (define (compress-fuse e) - (define (findfarg arg args fargs) ; for arg in args, return corresponding farg - (if (eq? arg (car args)) - (car fargs) - (findfarg arg (cdr args) (cdr fargs)))) - (if (fuse? e) - (let ((f (cadr e)) - (args (caddr e))) - (define (cf old-fargs old-args new-fargs new-args renames varfarg vararg) - (if (null? old-args) - (let ((nfargs (if (null? varfarg) new-fargs (cons varfarg new-fargs))) - (nargs (if (null? vararg) new-args (cons vararg new-args)))) - `(fuse (-> (tuple ,@(reverse nfargs)) ,(replace-vars (caddr f) renames)) - ,(reverse nargs))) - (let ((farg (car old-fargs)) (arg (car old-args))) - (cond - ((and (vararg? farg) (vararg? arg)) ; arg... must be the last argument - (if (null? varfarg) - (cf (cdr old-fargs) (cdr old-args) - new-fargs new-args renames farg arg) - (if (eq? (cadr vararg) (cadr arg)) - (cf (cdr old-fargs) (cdr old-args) - new-fargs new-args (cons (cons (cadr farg) (cadr varfarg)) renames) - varfarg vararg) - (error "multiple splatted args cannot be fused into a single broadcast")))) - ((julia-scalar? arg) ; inline numeric literals etc. - (cf (cdr old-fargs) (cdr old-args) - new-fargs new-args - (cons (cons farg arg) renames) - varfarg vararg)) - ((and (symbol? arg) (memq arg new-args)) ; combine duplicate args - ; (note: calling memq for every arg is O(length(args)^2) ... - ; ... would be better to replace with a hash table if args is long) - (cf (cdr old-fargs) (cdr old-args) - new-fargs new-args - (cons (cons farg (findfarg arg new-args new-fargs)) renames) - varfarg vararg)) - (else - (cf (cdr old-fargs) (cdr old-args) - (cons farg new-fargs) (cons arg new-args) renames varfarg vararg)))))) - (cf (cdadr f) args '() '() '() '() '())) - e)) ; (not (fuse? e)) - (let ((e (compress-fuse (dot-to-fuse rhs))) ; an expression '(fuse func args) if expr is a dot call + (let ((e (dot-to-fuse rhs #t)) ; an expression '(fuse func args) if expr is a dot call (lhs-view (ref-to-view lhs))) ; x[...] expressions on lhs turn in to view(x, ...) to update x in-place (if (fuse? e) + ; expanded to a fuse op call (if (null? lhs) - (expand-forms `(call (top broadcast) ,(from-lambda (cadr e)) ,@(caddr e))) - (expand-forms `(call (top broadcast!) ,(from-lambda (cadr e)) ,lhs-view ,@(caddr e)))) + (expand-forms `(call (|.| (top Broadcast) 'materialize) ,(cdr e))) + (expand-forms `(call (|.| (top Broadcast) 'materialize!) ,lhs-view ,(cdr e)))) + ; expanded to something else (like a getfield) (if (null? lhs) (expand-forms e) (expand-forms `(call (top broadcast!) (top identity) ,lhs-view ,e)))))) + (define (expand-where body var) (let* ((bounds (analyze-typevar var)) (v (car bounds))) diff --git a/stdlib/LinearAlgebra/src/LinearAlgebra.jl b/stdlib/LinearAlgebra/src/LinearAlgebra.jl index 5e8fd1ac1c517..2e6f801dc992c 100644 --- a/stdlib/LinearAlgebra/src/LinearAlgebra.jl +++ b/stdlib/LinearAlgebra/src/LinearAlgebra.jl @@ -19,6 +19,8 @@ import Base: USE_BLAS64, abs, acos, acosh, acot, acoth, acsc, acsch, adjoint, as StridedReshapedArray, strides, stride, tan, tanh, transpose, trunc, typed_hcat, vec using Base: hvcat_fill, iszero, IndexLinear, _length, promote_op, promote_typeof, @propagate_inbounds, @pure, reduce, typed_vcat +using Base.Broadcast: Broadcasted + # We use `_length` because of non-1 indices; releases after julia 0.5 # can go back to `length`. `_length(A)` is equivalent to `length(linearindices(A))`. @@ -327,6 +329,7 @@ include("special.jl") include("bitarray.jl") include("ldlt.jl") include("schur.jl") +include("structuredbroadcast.jl") include("deprecated.jl") const ⋅ = dot diff --git a/stdlib/LinearAlgebra/src/bidiag.jl b/stdlib/LinearAlgebra/src/bidiag.jl index d0a5ed25523de..9d2d617b78023 100644 --- a/stdlib/LinearAlgebra/src/bidiag.jl +++ b/stdlib/LinearAlgebra/src/bidiag.jl @@ -174,8 +174,6 @@ AbstractMatrix{T}(A::Bidiagonal) where {T} = convert(Bidiagonal{T}, A) convert(T::Type{<:Bidiagonal}, m::AbstractMatrix) = m isa T ? m : T(m) -broadcast(::typeof(big), B::Bidiagonal) = Bidiagonal(big.(B.dv), big.(B.ev), B.uplo) - # For B<:Bidiagonal, similar(B[, neweltype]) should yield a Bidiagonal matrix. # On the other hand, similar(B, [neweltype,] shape...) should yield a sparse matrix. # The first method below effects the former, and the second the latter. @@ -237,18 +235,9 @@ function size(M::Bidiagonal, d::Integer) end #Elementary operations -broadcast(::typeof(abs), M::Bidiagonal) = Bidiagonal(abs.(M.dv), abs.(M.ev), M.uplo) -broadcast(::typeof(round), M::Bidiagonal) = Bidiagonal(round.(M.dv), round.(M.ev), M.uplo) -broadcast(::typeof(trunc), M::Bidiagonal) = Bidiagonal(trunc.(M.dv), trunc.(M.ev), M.uplo) -broadcast(::typeof(floor), M::Bidiagonal) = Bidiagonal(floor.(M.dv), floor.(M.ev), M.uplo) -broadcast(::typeof(ceil), M::Bidiagonal) = Bidiagonal(ceil.(M.dv), ceil.(M.ev), M.uplo) for func in (:conj, :copy, :real, :imag) @eval ($func)(M::Bidiagonal) = Bidiagonal(($func)(M.dv), ($func)(M.ev), M.uplo) end -broadcast(::typeof(round), ::Type{T}, M::Bidiagonal) where {T<:Integer} = Bidiagonal(round.(T, M.dv), round.(T, M.ev), M.uplo) -broadcast(::typeof(trunc), ::Type{T}, M::Bidiagonal) where {T<:Integer} = Bidiagonal(trunc.(T, M.dv), trunc.(T, M.ev), M.uplo) -broadcast(::typeof(floor), ::Type{T}, M::Bidiagonal) where {T<:Integer} = Bidiagonal(floor.(T, M.dv), floor.(T, M.ev), M.uplo) -broadcast(::typeof(ceil), ::Type{T}, M::Bidiagonal) where {T<:Integer} = Bidiagonal(ceil.(T, M.dv), ceil.(T, M.ev), M.uplo) adjoint(B::Bidiagonal) = Adjoint(B) transpose(B::Bidiagonal) = Transpose(B) diff --git a/stdlib/LinearAlgebra/src/diagonal.jl b/stdlib/LinearAlgebra/src/diagonal.jl index 1470fec31e406..ed01cadd9f92f 100644 --- a/stdlib/LinearAlgebra/src/diagonal.jl +++ b/stdlib/LinearAlgebra/src/diagonal.jl @@ -112,7 +112,6 @@ isposdef(D::Diagonal) = all(x -> x > 0, D.diag) factorize(D::Diagonal) = D -broadcast(::typeof(abs), D::Diagonal) = Diagonal(abs.(D.diag)) real(D::Diagonal) = Diagonal(real(D.diag)) imag(D::Diagonal) = Diagonal(imag(D.diag)) diff --git a/stdlib/LinearAlgebra/src/structuredbroadcast.jl b/stdlib/LinearAlgebra/src/structuredbroadcast.jl new file mode 100644 index 0000000000000..24c410b2a299b --- /dev/null +++ b/stdlib/LinearAlgebra/src/structuredbroadcast.jl @@ -0,0 +1,180 @@ +## Broadcast styles +import Base.Broadcast +using Base.Broadcast: DefaultArrayStyle, broadcast_similar, tail + +struct StructuredMatrixStyle{T} <: Broadcast.AbstractArrayStyle{2} end +StructuredMatrixStyle{T}(::Val{2}) where {T} = StructuredMatrixStyle{T}() +StructuredMatrixStyle{T}(::Val{N}) where {T,N} = Broadcast.DefaultArrayStyle{N}() + +const StructuredMatrix = Union{Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal,LowerTriangular,UnitLowerTriangular,UpperTriangular,UnitUpperTriangular} +Broadcast.BroadcastStyle(::Type{T}) where {T<:StructuredMatrix} = StructuredMatrixStyle{T}() + +# Promotion of broadcasts between structured matrices. This is slightly unusual +# as we define them symmetrically. This allows us to have a fallback to DefaultArrayStyle{2}(). +# Diagonal can cavort with all the other structured matrix types. +# Bidiagonal doesn't know if it's upper or lower, so it becomes Tridiagonal +Broadcast.BroadcastStyle(::StructuredMatrixStyle{<:Diagonal}, ::StructuredMatrixStyle{<:Diagonal}) = + StructuredMatrixStyle{Diagonal}() +Broadcast.BroadcastStyle(::StructuredMatrixStyle{<:Diagonal}, ::StructuredMatrixStyle{<:Union{Bidiagonal,SymTridiagonal,Tridiagonal}}) = + StructuredMatrixStyle{Tridiagonal}() +Broadcast.BroadcastStyle(::StructuredMatrixStyle{<:Diagonal}, ::StructuredMatrixStyle{<:Union{LowerTriangular,UnitLowerTriangular}}) = + StructuredMatrixStyle{LowerTriangular}() +Broadcast.BroadcastStyle(::StructuredMatrixStyle{<:Diagonal}, ::StructuredMatrixStyle{<:Union{UpperTriangular,UnitUpperTriangular}}) = + StructuredMatrixStyle{UpperTriangular}() + +Broadcast.BroadcastStyle(::StructuredMatrixStyle{<:Bidiagonal}, ::StructuredMatrixStyle{<:Union{Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal}}) = + StructuredMatrixStyle{Tridiagonal}() +Broadcast.BroadcastStyle(::StructuredMatrixStyle{<:SymTridiagonal}, ::StructuredMatrixStyle{<:Union{Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal}}) = + StructuredMatrixStyle{Tridiagonal}() +Broadcast.BroadcastStyle(::StructuredMatrixStyle{<:Tridiagonal}, ::StructuredMatrixStyle{<:Union{Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal}}) = + StructuredMatrixStyle{Tridiagonal}() + +Broadcast.BroadcastStyle(::StructuredMatrixStyle{<:LowerTriangular}, ::StructuredMatrixStyle{<:Union{Diagonal,LowerTriangular,UnitLowerTriangular}}) = + StructuredMatrixStyle{LowerTriangular}() +Broadcast.BroadcastStyle(::StructuredMatrixStyle{<:UpperTriangular}, ::StructuredMatrixStyle{<:Union{Diagonal,UpperTriangular,UnitUpperTriangular}}) = + StructuredMatrixStyle{UpperTriangular}() +Broadcast.BroadcastStyle(::StructuredMatrixStyle{<:UnitLowerTriangular}, ::StructuredMatrixStyle{<:Union{Diagonal,LowerTriangular,UnitLowerTriangular}}) = + StructuredMatrixStyle{LowerTriangular}() +Broadcast.BroadcastStyle(::StructuredMatrixStyle{<:UnitUpperTriangular}, ::StructuredMatrixStyle{<:Union{Diagonal,UpperTriangular,UnitUpperTriangular}}) = + StructuredMatrixStyle{UpperTriangular}() + +# All other combinations fall back to the default style +Broadcast.BroadcastStyle(::StructuredMatrixStyle, ::StructuredMatrixStyle) = DefaultArrayStyle{2}() + +# And a definition akin to similar using the structured type: +structured_broadcast_alloc(bc, ::Type{<:Diagonal}, ::Type{ElType}, n) where {ElType} = + Diagonal(Array{ElType}(undef, n)) +# Bidiagonal is tricky as we need to know if it's upper or lower. The promotion +# system will return Tridiagonal when there's more than one Bidiagonal, but when +# there's only one, we need to make figure out upper or lower +find_bidiagonal() = throw(ArgumentError("could not find Bidiagonal within broadcast expression")) +find_bidiagonal(a::Bidiagonal, rest...) = a +find_bidiagonal(bc::Broadcast.Broadcasted, rest...) = find_bidiagonal(find_bidiagonal(bc.args...), rest...) +find_bidiagonal(x, rest...) = find_bidiagonal(rest...) +function structured_broadcast_alloc(bc, ::Type{<:Bidiagonal}, ::Type{ElType}, n) where {ElType} + ex = find_bidiagonal(bc) + return Bidiagonal(Array{ElType}(undef, n),Array{ElType}(undef, n-1), ex.uplo) +end +structured_broadcast_alloc(bc, ::Type{<:SymTridiagonal}, ::Type{ElType}, n) where {ElType} = + SymTridiagonal(Array{ElType}(undef, n),Array{ElType}(undef, n-1)) +structured_broadcast_alloc(bc, ::Type{<:Tridiagonal}, ::Type{ElType}, n) where {ElType} = + Tridiagonal(Array{ElType}(undef, n-1),Array{ElType}(undef, n),Array{ElType}(undef, n-1)) +structured_broadcast_alloc(bc, ::Type{<:LowerTriangular}, ::Type{ElType}, n) where {ElType} = + LowerTriangular(Array{ElType}(undef, n, n)) +structured_broadcast_alloc(bc, ::Type{<:UpperTriangular}, ::Type{ElType}, n) where {ElType} = + UpperTriangular(Array{ElType}(undef, n, n)) +structured_broadcast_alloc(bc, ::Type{<:UnitLowerTriangular}, ::Type{ElType}, n) where {ElType} = + UnitLowerTriangular(Array{ElType}(undef, n, n)) +structured_broadcast_alloc(bc, ::Type{<:UnitUpperTriangular}, ::Type{ElType}, n) where {ElType} = + UnitUpperTriangular(Array{ElType}(undef, n, n)) + +# A _very_ limited list of structure-preserving functions known at compile-time. This list is +# derived from the formerly-implemented `broadcast` methods in 0.6. Note that this must +# preserve both zeros and ones (for Unit***erTriangular) and symmetry (for SymTridiagonal) +const TypeFuncs = Union{typeof(round),typeof(trunc),typeof(floor),typeof(ceil)} +isstructurepreserving(bc::Broadcasted) = isstructurepreserving(bc.f, bc.args...) +isstructurepreserving(::Union{typeof(abs),typeof(big)}, ::StructuredMatrix) = true +isstructurepreserving(::TypeFuncs, ::StructuredMatrix) = true +isstructurepreserving(::TypeFuncs, ::Ref{<:Type}, ::StructuredMatrix) = true +isstructurepreserving(f, args...) = false + +_iszero(n::Number) = iszero(n) +_iszero(x) = x == 0 +fzeropreserving(bc) = (v = fzero(bc); !ismissing(v) && _iszero(v)) +# Very conservatively only allow Numbers and Types in this speculative zero-test pass +fzero(x::Number) = x +fzero(::Type{T}) where T = T +fzero(S::StructuredMatrix) = zero(eltype(S)) +fzero(x) = missing +function fzero(bc::Broadcast.Broadcasted) + args = map(fzero, bc.args) + return any(ismissing, args) ? missing : bc.f(args...) +end + +function Broadcast.broadcast_similar(::StructuredMatrixStyle{T}, ::Type{ElType}, inds, bc) where {T,ElType} + if isstructurepreserving(bc) || (fzeropreserving(bc) && !(T <: Union{SymTridiagonal,UnitLowerTriangular,UnitUpperTriangular})) + return structured_broadcast_alloc(bc, T, ElType, length(inds[1])) + end + return broadcast_similar(DefaultArrayStyle{2}(), ElType, inds, bc) +end + +function copyto!(dest::Diagonal, bc::Broadcasted{<:StructuredMatrixStyle}) + axs = axes(dest) + axes(bc) == axs || Broadcast.throwdm(axes(bc), axs) + for i in axs[1] + dest.diag[i] = Broadcast._broadcast_getindex(bc, CartesianIndex(i, i)) + end + return dest +end + +function copyto!(dest::Bidiagonal, bc::Broadcasted{<:StructuredMatrixStyle}) + axs = axes(dest) + axes(bc) == axs || Broadcast.throwdm(axes(bc), axs) + for i in axs[1] + dest.dv[i] = Broadcast._broadcast_getindex(bc, CartesianIndex(i, i)) + end + if dest.uplo == 'U' + for i = 1:size(dest, 1)-1 + dest.ev[i] = Broadcast._broadcast_getindex(bc, CartesianIndex(i, i+1)) + end + else + for i = 1:size(dest, 1)-1 + dest.ev[i] = Broadcast._broadcast_getindex(bc, CartesianIndex(i+1, i)) + end + end + return dest +end + +function copyto!(dest::SymTridiagonal, bc::Broadcasted{<:StructuredMatrixStyle}) + axs = axes(dest) + axes(bc) == axs || Broadcast.throwdm(axes(bc), axs) + for i in axs[1] + dest.dv[i] = Broadcast._broadcast_getindex(bc, CartesianIndex(i, i)) + end + for i = 1:size(dest, 1)-1 + dest.ev[i] = Broadcast._broadcast_getindex(bc, CartesianIndex(i, i+1)) + end + return dest +end + +function copyto!(dest::Tridiagonal, bc::Broadcasted{<:StructuredMatrixStyle}) + axs = axes(dest) + axes(bc) == axs || Broadcast.throwdm(axes(bc), axs) + for i in axs[1] + dest.d[i] = Broadcast._broadcast_getindex(bc, CartesianIndex(i, i)) + end + for i = 1:size(dest, 1)-1 + dest.du[i] = Broadcast._broadcast_getindex(bc, CartesianIndex(i, i+1)) + dest.dl[i] = Broadcast._broadcast_getindex(bc, CartesianIndex(i+1, i)) + end + return dest +end + +function copyto!(dest::LowerTriangular, bc::Broadcasted{<:StructuredMatrixStyle}) + axs = axes(dest) + axes(bc) == axs || Broadcast.throwdm(axes(bc), axs) + for j in axs[2] + for i in j:axs[1][end] + dest.data[i,j] = Broadcast._broadcast_getindex(bc, CartesianIndex(i, j)) + end + end + return dest +end + +function copyto!(dest::UpperTriangular, bc::Broadcasted{<:StructuredMatrixStyle}) + axs = axes(dest) + axes(bc) == axs || Broadcast.throwdm(axes(bc), axs) + for j in axs[2] + for i in 1:j + dest.data[i,j] = Broadcast._broadcast_getindex(bc, CartesianIndex(i, j)) + end + end + return dest +end + +# We can also implement `map` and its promotion in terms of broadcast with a stricter dimension check +function map(f, A::StructuredMatrix, Bs::StructuredMatrix...) + sz = size(A) + all(map(B->size(B)==sz, Bs)) || throw(DimensionMismatch("dimensions must match")) + return f.(A, Bs...) +end diff --git a/stdlib/LinearAlgebra/src/triangular.jl b/stdlib/LinearAlgebra/src/triangular.jl index b6bfcb81b13ad..c6ef222cacc7b 100644 --- a/stdlib/LinearAlgebra/src/triangular.jl +++ b/stdlib/LinearAlgebra/src/triangular.jl @@ -37,11 +37,8 @@ for t in (:LowerTriangular, :UnitLowerTriangular, :UpperTriangular, copy(A::$t) = $t(copy(A.data)) - broadcast(::typeof(big), A::$t) = $t(big.(A.data)) - real(A::$t{<:Real}) = A real(A::$t{<:Complex}) = (B = real(A.data); $t(B)) - broadcast(::typeof(abs), A::$t) = $t(abs.(A.data)) end end diff --git a/stdlib/LinearAlgebra/src/tridiag.jl b/stdlib/LinearAlgebra/src/tridiag.jl index 1baf99202c5e1..e29266452a372 100644 --- a/stdlib/LinearAlgebra/src/tridiag.jl +++ b/stdlib/LinearAlgebra/src/tridiag.jl @@ -115,18 +115,9 @@ similar(S::SymTridiagonal, ::Type{T}) where {T} = SymTridiagonal(similar(S.dv, T # similar(S::SymTridiagonal, ::Type{T}, dims::Union{Dims{1},Dims{2}}) where {T} = spzeros(T, dims...) #Elementary operations -broadcast(::typeof(abs), M::SymTridiagonal) = SymTridiagonal(abs.(M.dv), abs.(M.ev)) -broadcast(::typeof(round), M::SymTridiagonal) = SymTridiagonal(round.(M.dv), round.(M.ev)) -broadcast(::typeof(trunc), M::SymTridiagonal) = SymTridiagonal(trunc.(M.dv), trunc.(M.ev)) -broadcast(::typeof(floor), M::SymTridiagonal) = SymTridiagonal(floor.(M.dv), floor.(M.ev)) -broadcast(::typeof(ceil), M::SymTridiagonal) = SymTridiagonal(ceil.(M.dv), ceil.(M.ev)) for func in (:conj, :copy, :real, :imag) @eval ($func)(M::SymTridiagonal) = SymTridiagonal(($func)(M.dv), ($func)(M.ev)) end -broadcast(::typeof(round), ::Type{T}, M::SymTridiagonal) where {T<:Integer} = SymTridiagonal(round.(T, M.dv), round.(T, M.ev)) -broadcast(::typeof(trunc), ::Type{T}, M::SymTridiagonal) where {T<:Integer} = SymTridiagonal(trunc.(T, M.dv), trunc.(T, M.ev)) -broadcast(::typeof(floor), ::Type{T}, M::SymTridiagonal) where {T<:Integer} = SymTridiagonal(floor.(T, M.dv), floor.(T, M.ev)) -broadcast(::typeof(ceil), ::Type{T}, M::SymTridiagonal) where {T<:Integer} = SymTridiagonal(ceil.(T, M.dv), ceil.(T, M.ev)) transpose(S::SymTridiagonal) = S adjoint(S::SymTridiagonal{<:Real}) = S @@ -497,24 +488,11 @@ similar(M::Tridiagonal, ::Type{T}) where {T} = Tridiagonal(similar(M.dl, T), sim copyto!(dest::Tridiagonal, src::Tridiagonal) = (copyto!(dest.dl, src.dl); copyto!(dest.d, src.d); copyto!(dest.du, src.du); dest) #Elementary operations -broadcast(::typeof(abs), M::Tridiagonal) = Tridiagonal(abs.(M.dl), abs.(M.d), abs.(M.du)) -broadcast(::typeof(round), M::Tridiagonal) = Tridiagonal(round.(M.dl), round.(M.d), round.(M.du)) -broadcast(::typeof(trunc), M::Tridiagonal) = Tridiagonal(trunc.(M.dl), trunc.(M.d), trunc.(M.du)) -broadcast(::typeof(floor), M::Tridiagonal) = Tridiagonal(floor.(M.dl), floor.(M.d), floor.(M.du)) -broadcast(::typeof(ceil), M::Tridiagonal) = Tridiagonal(ceil.(M.dl), ceil.(M.d), ceil.(M.du)) for func in (:conj, :copy, :real, :imag) @eval function ($func)(M::Tridiagonal) Tridiagonal(($func)(M.dl), ($func)(M.d), ($func)(M.du)) end end -broadcast(::typeof(round), ::Type{T}, M::Tridiagonal) where {T<:Integer} = - Tridiagonal(round.(T, M.dl), round.(T, M.d), round.(T, M.du)) -broadcast(::typeof(trunc), ::Type{T}, M::Tridiagonal) where {T<:Integer} = - Tridiagonal(trunc.(T, M.dl), trunc.(T, M.d), trunc.(T, M.du)) -broadcast(::typeof(floor), ::Type{T}, M::Tridiagonal) where {T<:Integer} = - Tridiagonal(floor.(T, M.dl), floor.(T, M.d), floor.(T, M.du)) -broadcast(::typeof(ceil), ::Type{T}, M::Tridiagonal) where {T<:Integer} = - Tridiagonal(ceil.(T, M.dl), ceil.(T, M.d), ceil.(T, M.du)) adjoint(S::Tridiagonal) = Adjoint(S) transpose(S::Tridiagonal) = Transpose(S) @@ -577,6 +555,7 @@ function Base.replace_in_print_matrix(A::Tridiagonal,i::Integer,j::Integer,s::Ab i==j-1||i==j||i==j+1 ? s : Base.replace_with_centered_mark(s) end + #tril and triu istriu(M::Tridiagonal) = iszero(M.dl) diff --git a/stdlib/LinearAlgebra/src/uniformscaling.jl b/stdlib/LinearAlgebra/src/uniformscaling.jl index a1644e951100c..5c2ba8720111f 100644 --- a/stdlib/LinearAlgebra/src/uniformscaling.jl +++ b/stdlib/LinearAlgebra/src/uniformscaling.jl @@ -208,10 +208,10 @@ end \(x::Number, J::UniformScaling) = UniformScaling(x\J.λ) -broadcast(::typeof(*), x::Number,J::UniformScaling) = UniformScaling(x*J.λ) -broadcast(::typeof(*), J::UniformScaling,x::Number) = UniformScaling(J.λ*x) +Broadcast.make(::typeof(*), x::Number,J::UniformScaling) = UniformScaling(x*J.λ) +Broadcast.make(::typeof(*), J::UniformScaling,x::Number) = UniformScaling(J.λ*x) -broadcast(::typeof(/), J::UniformScaling,x::Number) = UniformScaling(J.λ/x) +Broadcast.make(::typeof(/), J::UniformScaling,x::Number) = UniformScaling(J.λ/x) ==(J1::UniformScaling,J2::UniformScaling) = (J1.λ == J2.λ) diff --git a/stdlib/LinearAlgebra/test/structuredbroadcast.jl b/stdlib/LinearAlgebra/test/structuredbroadcast.jl new file mode 100644 index 0000000000000..c8bef049fd01a --- /dev/null +++ b/stdlib/LinearAlgebra/test/structuredbroadcast.jl @@ -0,0 +1,101 @@ +module TestStructuredBroadcast +using Test, LinearAlgebra + +@testset "broadcast[!] over combinations of scalars, structured matrices, and dense vectors/matrices" begin + N = 10 + s = rand() + fV = rand(N) + fA = rand(N, N) + Z = copy(fA) + D = Diagonal(rand(N)) + B = Bidiagonal(rand(N), rand(N - 1), :U) + T = Tridiagonal(rand(N - 1), rand(N), rand(N - 1)) + U = UpperTriangular(rand(N,N)) + L = LowerTriangular(rand(N,N)) + structuredarrays = (D, B, T, U, L) + fstructuredarrays = map(Array, structuredarrays) + for (X, fX) in zip(structuredarrays, fstructuredarrays) + @test (Q = broadcast(sin, X); typeof(Q) == typeof(X) && Q == broadcast(sin, fX)) + @test broadcast!(sin, Z, X) == broadcast(sin, fX) + @test (Q = broadcast(cos, X); Q isa Matrix && Q == broadcast(cos, fX)) + @test broadcast!(cos, Z, X) == broadcast(cos, fX) + @test (Q = broadcast(*, s, X); typeof(Q) == typeof(X) && Q == broadcast(*, s, fX)) + @test broadcast!(*, Z, s, X) == broadcast(*, s, fX) + @test (Q = broadcast(+, fV, fA, X); Q isa Matrix && Q == broadcast(+, fV, fA, fX)) + @test broadcast!(+, Z, fV, fA, X) == broadcast(+, fV, fA, fX) + @test (Q = broadcast(*, s, fV, fA, X); Q isa Matrix && Q == broadcast(*, s, fV, fA, fX)) + @test broadcast!(*, Z, s, fV, fA, X) == broadcast(*, s, fV, fA, fX) + for (Y, fY) in zip(structuredarrays, fstructuredarrays) + @test broadcast(+, X, Y) == broadcast(+, fX, fY) + @test broadcast!(+, Z, X, Y) == broadcast(+, fX, fY) + @test broadcast(*, X, Y) == broadcast(*, fX, fY) + @test broadcast!(*, Z, X, Y) == broadcast(*, fX, fY) + end + end + diagonals = (D, B, T) + fdiagonals = map(Array, diagonals) + for (X, fX) in zip(diagonals, fdiagonals) + for (Y, fY) in zip(diagonals, fdiagonals) + @test broadcast(+, X, Y)::Union{Diagonal,Bidiagonal,Tridiagonal} == broadcast(+, fX, fY) + @test broadcast!(+, Z, X, Y) == broadcast(+, fX, fY) + @test broadcast(*, X, Y)::Union{Diagonal,Bidiagonal,Tridiagonal} == broadcast(*, fX, fY) + @test broadcast!(*, Z, X, Y) == broadcast(*, fX, fY) + end + end +end + +@testset "broadcast! where the destination is a structured matrix" begin + N = 5 + A = rand(N, N) + sA = A + copy(A') + D = Diagonal(rand(N)) + B = Bidiagonal(rand(N), rand(N - 1), :U) + T = Tridiagonal(rand(N - 1), rand(N), rand(N - 1)) + @test broadcast!(sin, copy(D), D) == Diagonal(sin.(D)) + @test broadcast!(sin, copy(B), B) == Bidiagonal(sin.(B), :U) + @test broadcast!(sin, copy(T), T) == Tridiagonal(sin.(T)) + @test broadcast!(*, copy(D), D, A) == Diagonal(broadcast(*, D, A)) + @test broadcast!(*, copy(B), B, A) == Bidiagonal(broadcast(*, B, A), :U) + @test broadcast!(*, copy(T), T, A) == Tridiagonal(broadcast(*, T, A)) +end + +@testset "map[!] over combinations of structured matrices" begin + N = 10 + fA = rand(N, N) + Z = copy(fA) + D = Diagonal(rand(N)) + B = Bidiagonal(rand(N), rand(N - 1), :U) + T = Tridiagonal(rand(N - 1), rand(N), rand(N - 1)) + U = UpperTriangular(rand(N,N)) + L = LowerTriangular(rand(N,N)) + structuredarrays = (D, B, T, U, L) + fstructuredarrays = map(Array, structuredarrays) + for (X, fX) in zip(structuredarrays, fstructuredarrays) + @test (Q = map(sin, X); typeof(Q) == typeof(X) && Q == map(sin, fX)) + @test map!(sin, Z, X) == map(sin, fX) + @test (Q = map(cos, X); Q isa Matrix && Q == map(cos, fX)) + @test map!(cos, Z, X) == map(cos, fX) + @test (Q = map(+, fA, X); Q isa Matrix && Q == map(+, fA, fX)) + @test map!(+, Z, fA, X) == map(+, fA, fX) + for (Y, fY) in zip(structuredarrays, fstructuredarrays) + @test map(+, X, Y) == map(+, fX, fY) + @test map!(+, Z, X, Y) == map(+, fX, fY) + @test map(*, X, Y) == map(*, fX, fY) + @test map!(*, Z, X, Y) == map(*, fX, fY) + @test map(+, X, fA, Y) == map(+, fX, fA, fY) + @test map!(+, Z, X, fA, Y) == map(+, fX, fA, fY) + end + end + diagonals = (D, B, T) + fdiagonals = map(Array, diagonals) + for (X, fX) in zip(diagonals, fdiagonals) + for (Y, fY) in zip(diagonals, fdiagonals) + @test map(+, X, Y)::Union{Diagonal,Bidiagonal,Tridiagonal} == broadcast(+, fX, fY) + @test map!(+, Z, X, Y) == broadcast(+, fX, fY) + @test map(*, X, Y)::Union{Diagonal,Bidiagonal,Tridiagonal} == broadcast(*, fX, fY) + @test map!(*, Z, X, Y) == broadcast(*, fX, fY) + end + end +end + +end diff --git a/stdlib/SparseArrays/src/higherorderfns.jl b/stdlib/SparseArrays/src/higherorderfns.jl index bca1585985ad8..3ee447c8a2c54 100644 --- a/stdlib/SparseArrays/src/higherorderfns.jl +++ b/stdlib/SparseArrays/src/higherorderfns.jl @@ -4,15 +4,16 @@ module HigherOrderFns # This module provides higher order functions specialized for sparse arrays, # particularly map[!]/broadcast[!] for SparseVectors and SparseMatrixCSCs at present. -import Base: map, map!, broadcast, broadcast! +import Base: map, map!, broadcast, copy, copyto! using Base: front, tail, to_shape using ..SparseArrays: SparseVector, SparseMatrixCSC, AbstractSparseVector, AbstractSparseMatrix, AbstractSparseArray, indtype, nnz, nzrange -using Base.Broadcast: BroadcastStyle +using Base.Broadcast: BroadcastStyle, Broadcasted, flatten using LinearAlgebra # This module is organized as follows: +# (0) Define BroadcastStyle rules and convenience types for dispatch # (1) Define a common interface to SparseVectors and SparseMatrixCSCs sufficient for # map[!]/broadcast[!]'s purposes. The methods below are written against this interface. # (2) Define entry points for map[!] (short children of _map_[not]zeropres!). @@ -29,11 +30,79 @@ using LinearAlgebra # (12) Define map[!] methods handling combinations of sparse and structured matrices. +# (0) BroadcastStyle rules and convenience types for dispatch + +SparseVecOrMat = Union{SparseVector,SparseMatrixCSC} + +# broadcast container type promotion for combinations of sparse arrays and other types +struct SparseVecStyle <: Broadcast.AbstractArrayStyle{1} end +struct SparseMatStyle <: Broadcast.AbstractArrayStyle{2} end +Broadcast.BroadcastStyle(::Type{<:SparseVector}) = SparseVecStyle() +Broadcast.BroadcastStyle(::Type{<:SparseMatrixCSC}) = SparseMatStyle() +const SPVM = Union{SparseVecStyle,SparseMatStyle} + +# SparseVecStyle handles 0-1 dimensions, SparseMatStyle 0-2 dimensions. +# SparseVecStyle promotes to SparseMatStyle for 2 dimensions. +# Fall back to DefaultArrayStyle for higher dimensionality. +SparseVecStyle(::Val{0}) = SparseVecStyle() +SparseVecStyle(::Val{1}) = SparseVecStyle() +SparseVecStyle(::Val{2}) = SparseMatStyle() +SparseVecStyle(::Val{N}) where N = Broadcast.DefaultArrayStyle{N}() +SparseMatStyle(::Val{0}) = SparseMatStyle() +SparseMatStyle(::Val{1}) = SparseMatStyle() +SparseMatStyle(::Val{2}) = SparseMatStyle() +SparseMatStyle(::Val{N}) where N = Broadcast.DefaultArrayStyle{N}() + +Broadcast.BroadcastStyle(::SparseMatStyle, ::SparseVecStyle) = SparseMatStyle() + +# Tuples promote to dense +Broadcast.BroadcastStyle(::SparseVecStyle, ::Broadcast.Style{Tuple}) = Broadcast.DefaultArrayStyle{1}() +Broadcast.BroadcastStyle(::SparseMatStyle, ::Broadcast.Style{Tuple}) = Broadcast.DefaultArrayStyle{2}() + +struct PromoteToSparse <: Broadcast.AbstractArrayStyle{2} end +PromoteToSparse(::Val{0}) = PromoteToSparse() +PromoteToSparse(::Val{1}) = PromoteToSparse() +PromoteToSparse(::Val{2}) = PromoteToSparse() +PromoteToSparse(::Val{N}) where N = Broadcast.DefaultArrayStyle{N}() + +const StructuredMatrix = Union{Diagonal,Bidiagonal,Tridiagonal,SymTridiagonal} +Broadcast.BroadcastStyle(::Type{<:Adjoint{T,<:Union{SparseVector,SparseMatrixCSC}} where T}) = PromoteToSparse() +Broadcast.BroadcastStyle(::Type{<:Transpose{T,<:Union{SparseVector,SparseMatrixCSC}} where T}) = PromoteToSparse() + +Broadcast.BroadcastStyle(s::SPVM, ::Broadcast.AbstractArrayStyle{0}) = s +Broadcast.BroadcastStyle(::SPVM, ::Broadcast.DefaultArrayStyle{1}) = PromoteToSparse() +Broadcast.BroadcastStyle(::SPVM, ::Broadcast.DefaultArrayStyle{2}) = PromoteToSparse() + +Broadcast.BroadcastStyle(::SPVM, ::LinearAlgebra.StructuredMatrixStyle{<:StructuredMatrix}) = PromoteToSparse() +Broadcast.BroadcastStyle(::PromoteToSparse, ::LinearAlgebra.StructuredMatrixStyle{<:StructuredMatrix}) = PromoteToSparse() + +Broadcast.BroadcastStyle(::PromoteToSparse, ::SPVM) = PromoteToSparse() +Broadcast.BroadcastStyle(::PromoteToSparse, ::Broadcast.Style{Tuple}) = Broadcast.DefaultArrayStyle{2}() + +# FIXME: currently sparse broadcasts are only well-tested on known array types, while any AbstractArray +# could report itself as a DefaultArrayStyle(). +# See https://github.com/JuliaLang/julia/pull/23939#pullrequestreview-72075382 for more details +is_supported_sparse_broadcast() = true +is_supported_sparse_broadcast(::AbstractArray, rest...) = false +is_supported_sparse_broadcast(::AbstractSparseArray, rest...) = is_supported_sparse_broadcast(rest...) +is_supported_sparse_broadcast(::StructuredMatrix, rest...) = is_supported_sparse_broadcast(rest...) +is_supported_sparse_broadcast(::Array, rest...) = is_supported_sparse_broadcast(rest...) +is_supported_sparse_broadcast(t::Union{Transpose, Adjoint}, rest...) = is_supported_sparse_broadcast(t.parent, rest...) +is_supported_sparse_broadcast(x, rest...) = axes(x) === () && is_supported_sparse_broadcast(rest...) +is_supported_sparse_broadcast(x::Ref, rest...) = is_supported_sparse_broadcast(rest...) + +# Dispatch on broadcast operations by number of arguments +const Broadcasted0{Style<:Union{Nothing,BroadcastStyle},Axes,F} = + Broadcasted{Style,Axes,F,Tuple{}} +const SpBroadcasted1{Style<:SPVM,Axes,F,Args<:Tuple{SparseVecOrMat}} = + Broadcasted{Style,Axes,F,Args} +const SpBroadcasted2{Style<:SPVM,Axes,F,Args<:Tuple{SparseVecOrMat,SparseVecOrMat}} = + Broadcasted{Style,Axes,F,Args} + # (1) The definitions below provide a common interface to sparse vectors and matrices # sufficient for the purposes of map[!]/broadcast[!]. This interface treats sparse vectors # as n-by-one sparse matrices which, though technically incorrect, is how broacast[!] views # sparse vectors in practice. -SparseVecOrMat = Union{SparseVector,SparseMatrixCSC} @inline numrows(A::SparseVector) = A.n @inline numrows(A::SparseMatrixCSC) = A.m @inline numcols(A::SparseVector) = 1 @@ -85,18 +154,18 @@ function _noshapecheck_map(f::Tf, A::SparseVecOrMat, Bs::Vararg{SparseVecOrMat,N fofzeros = f(_zeros_eltypes(A, Bs...)...) fpreszeros = _iszero(fofzeros) maxnnzC = fpreszeros ? min(length(A), _sumnnzs(A, Bs...)) : length(A) - entrytypeC = Base.Broadcast.combine_eltypes(f, A, Bs...) + entrytypeC = Base.Broadcast.combine_eltypes(f, (A, Bs...)) indextypeC = _promote_indtype(A, Bs...) C = _allocres(size(A), indextypeC, entrytypeC, maxnnzC) return fpreszeros ? _map_zeropres!(f, C, A, Bs...) : _map_notzeropres!(f, fofzeros, C, A, Bs...) end # (3) broadcast[!] entry points -broadcast(f::Tf, A::SparseVector) where {Tf} = _noshapecheck_map(f, A) -broadcast(f::Tf, A::SparseMatrixCSC) where {Tf} = _noshapecheck_map(f, A) +copy(bc::SpBroadcasted1) = _noshapecheck_map(bc.f, bc.args[1]) -@inline function broadcast!(f::Tf, C::SparseVecOrMat, ::Nothing) where Tf +@inline function copyto!(C::SparseVecOrMat, bc::Broadcasted0{Nothing}) isempty(C) && return _finishempty!(C) + f = bc.f fofnoargs = f() if _iszero(fofnoargs) # f() is zero, so empty C trimstorage!(C, 0) @@ -109,19 +178,12 @@ broadcast(f::Tf, A::SparseMatrixCSC) where {Tf} = _noshapecheck_map(f, A) return C end -# the following three similar defs are necessary for type stability in the mixed vector/matrix case -broadcast(f::Tf, A::SparseVector, Bs::Vararg{SparseVector,N}) where {Tf,N} = - _aresameshape(A, Bs...) ? _noshapecheck_map(f, A, Bs...) : _diffshape_broadcast(f, A, Bs...) -broadcast(f::Tf, A::SparseMatrixCSC, Bs::Vararg{SparseMatrixCSC,N}) where {Tf,N} = - _aresameshape(A, Bs...) ? _noshapecheck_map(f, A, Bs...) : _diffshape_broadcast(f, A, Bs...) -broadcast(f::Tf, A::SparseVecOrMat, Bs::Vararg{SparseVecOrMat,N}) where {Tf,N} = - _diffshape_broadcast(f, A, Bs...) function _diffshape_broadcast(f::Tf, A::SparseVecOrMat, Bs::Vararg{SparseVecOrMat,N}) where {Tf,N} fofzeros = f(_zeros_eltypes(A, Bs...)...) fpreszeros = _iszero(fofzeros) indextypeC = _promote_indtype(A, Bs...) - entrytypeC = Base.Broadcast.combine_eltypes(f, A, Bs...) - shapeC = to_shape(Base.Broadcast.combine_indices(A, Bs...)) + entrytypeC = Base.Broadcast.combine_eltypes(f, (A, Bs...)) + shapeC = to_shape(Base.Broadcast.combine_axes(A, Bs...)) maxnnzC = fpreszeros ? _checked_maxnnzbcres(shapeC, A, Bs...) : _densennz(shapeC) C = _allocres(shapeC, indextypeC, entrytypeC, maxnnzC) return fpreszeros ? _broadcast_zeropres!(f, C, A, Bs...) : @@ -141,6 +203,10 @@ end @inline _aresameshape(A, B) = size(A) == size(B) @inline _aresameshape(A, B, Cs...) = _aresameshape(A, B) ? _aresameshape(B, Cs...) : false @inline _checksameshape(As...) = _aresameshape(As...) || throw(DimensionMismatch("argument shapes must match")) +@inline _all_args_isa(t::Tuple{Any}, ::Type{T}) where T = isa(t[1], T) +@inline _all_args_isa(t::Tuple{Any,Vararg{Any}}, ::Type{T}) where T = isa(t[1], T) & _all_args_isa(tail(t), T) +@inline _all_args_isa(t::Tuple{Broadcasted}, ::Type{T}) where T = _all_args_isa(t[1].args, T) +@inline _all_args_isa(t::Tuple{Broadcasted,Vararg{Any}}, ::Type{T}) where T = _all_args_isa(t[1].args, T) & _all_args_isa(tail(t), T) @inline _densennz(shape::NTuple{1}) = shape[1] @inline _densennz(shape::NTuple{2}) = shape[1] * shape[2] _maxnnzfrom(shape::NTuple{1}, A) = nnz(A) * div(shape[1], A.n) @@ -887,37 +953,56 @@ end # (10) broadcast over combinations of broadcast scalars and sparse vectors/matrices -# broadcast container type promotion for combinations of sparse arrays and other types -struct SparseVecStyle <: Broadcast.AbstractArrayStyle{1} end -struct SparseMatStyle <: Broadcast.AbstractArrayStyle{2} end -Broadcast.BroadcastStyle(::Type{<:SparseVector}) = SparseVecStyle() -Broadcast.BroadcastStyle(::Type{<:SparseMatrixCSC}) = SparseMatStyle() -const SPVM = Union{SparseVecStyle,SparseMatStyle} +# broadcast entry points for combinations of sparse arrays and other (scalar) types +@inline function copy(bc::Broadcasted{<:SPVM}) + bcf = flatten(bc) + return _copy(bcf.f, bcf.args...) +end -# SparseVecStyle handles 0-1 dimensions, SparseMatStyle 0-2 dimensions. -# SparseVecStyle promotes to SparseMatStyle for 2 dimensions. -# Fall back to DefaultArrayStyle for higher dimensionality. -SparseVecStyle(::Val{0}) = SparseVecStyle() -SparseVecStyle(::Val{1}) = SparseVecStyle() -SparseVecStyle(::Val{2}) = SparseMatStyle() -SparseVecStyle(::Val{N}) where N = Broadcast.DefaultArrayStyle{N}() -SparseMatStyle(::Val{0}) = SparseMatStyle() -SparseMatStyle(::Val{1}) = SparseMatStyle() -SparseMatStyle(::Val{2}) = SparseMatStyle() -SparseMatStyle(::Val{N}) where N = Broadcast.DefaultArrayStyle{N}() +_copy(f, args::SparseVector...) = _shapecheckbc(f, args...) +_copy(f, args::SparseMatrixCSC...) = _shapecheckbc(f, args...) +_copy(f, args::SparseVecOrMat...) = _diffshape_broadcast(f, args...) +# Otherwise, we incorporate scalars into the function and re-dispatch +function _copy(f, args...) + parevalf, passedargstup = capturescalars(f, args) + return _copy(parevalf, passedargstup...) +end -Broadcast.BroadcastStyle(::SparseMatStyle, ::SparseVecStyle) = SparseMatStyle() +function _shapecheckbc(f, args...) + _aresameshape(args...) ? _noshapecheck_map(f, args...) : _diffshape_broadcast(f, args...) +end -# Tuples promote to dense -Broadcast.BroadcastStyle(::SparseVecStyle, ::Broadcast.Style{Tuple}) = Broadcast.DefaultArrayStyle{1}() -Broadcast.BroadcastStyle(::SparseMatStyle, ::Broadcast.Style{Tuple}) = Broadcast.DefaultArrayStyle{2}() -# broadcast entry points for combinations of sparse arrays and other (scalar) types -function broadcast(f, ::SPVM, ::Nothing, ::Nothing, mixedargs::Vararg{Any,N}) where N - parevalf, passedargstup = capturescalars(f, mixedargs) - return broadcast(parevalf, passedargstup...) +@inline function copyto!(dest::SparseVecOrMat, bc::Broadcasted{<:SPVM}) + if bc.f === identity && bc isa SpBroadcasted1 && Base.axes(dest) == (A = bc.args[1]; Base.axes(A)) + return copyto!(dest, A) + end + bcf = flatten(bc) + As = map(arg->Base.unalias(dest, arg), bcf.args) + return _copyto!(bcf.f, dest, As...) +end + +@inline function _copyto!(f, dest, As::SparseVecOrMat...) + _aresameshape(dest, As...) && return _noshapecheck_map!(f, dest, As...) + Base.Broadcast.check_broadcast_axes(axes(dest), As...) + fofzeros = f(_zeros_eltypes(As...)...) + if _iszero(fofzeros) + return _broadcast_zeropres!(f, dest, As...) + else + return _broadcast_notzeropres!(f, fofzeros, dest, As...) + end +end + +@inline function _copyto!(f, dest, args...) + # args contains nothing but SparseVecOrMat and scalars + # See below for capturescalars + parevalf, passedsrcargstup = capturescalars(f, args) + _copyto!(parevalf, dest, passedsrcargstup...) +end + +struct CapturedScalars{F, Args, Order} + args::Args end -# for broadcast! see (11) # capturescalars takes a function (f) and a tuple of mixed sparse vectors/matrices and # broadcast scalar arguments (mixedargs), and returns a function (parevalf, i.e. partially @@ -930,6 +1015,13 @@ end return (parevalf, passedsrcargstup) end end +# Work around losing Type{T}s as DataTypes within the tuple that makeargs creates +@inline capturescalars(f, mixedargs::Tuple{Ref{Type{T}}, Vararg{Any}}) where {T} = + capturescalars((args...)->f(T, args...), Base.tail(mixedargs)) +@inline capturescalars(f, mixedargs::Tuple{SparseVecOrMat, Ref{Type{T}}, Vararg{Any}}) where {T} = + capturescalars((a1, args...)->f(a1, T, args...), (mixedargs[1], Base.tail(Base.tail(mixedargs))...)) +@inline capturescalars(f, mixedargs::Tuple{Union{Ref,AbstractArray{0}}, Ref{Type{T}}, Vararg{Any}}) where {T} = + capturescalars((args...)->f(mixedargs[1], T, args...), Base.tail(Base.tail(mixedargs))) nonscalararg(::SparseVecOrMat) = true nonscalararg(::Any) = false @@ -942,11 +1034,17 @@ end @inline function _capturescalars(arg, mixedargs...) let (rest, f) = _capturescalars(mixedargs...) if nonscalararg(arg) - return (arg, rest...), (head, tail...) -> (head, f(tail...)...) # pass-through to broadcast + return (arg, rest...), @inline function(head, tail...) + (head, f(tail...)...) + end # pass-through to broadcast elseif scalarwrappedarg(arg) - return rest, (tail...) -> (arg[], f(tail...)...) # unwrap and add back scalararg after (in makeargs) + return rest, @inline function(tail...) + (arg[], f(tail...)...) # TODO: This can put a Type{T} in a tuple + end # unwrap and add back scalararg after (in makeargs) else - return rest, (tail...) -> (arg, f(tail...)...) # add back scalararg after (in makeargs) + return rest, @inline function(tail...) + (arg, f(tail...)...) + end # add back scalararg after (in makeargs) end end end @@ -972,69 +1070,18 @@ broadcast(f::Tf, A::SparseMatrixCSC, ::Type{T}) where {Tf,T} = broadcast(x -> f( # vectors/matrices, promote all structured matrices and dense vectors/matrices to sparse # and rebroadcast. otherwise, divert to generic AbstractArray broadcast code. -struct PromoteToSparse <: Broadcast.AbstractArrayStyle{2} end -PromoteToSparse(::Val{0}) = PromoteToSparse() -PromoteToSparse(::Val{1}) = PromoteToSparse() -PromoteToSparse(::Val{2}) = PromoteToSparse() -PromoteToSparse(::Val{N}) where N = Broadcast.DefaultArrayStyle{N}() - -const StructuredMatrix = Union{Diagonal,Bidiagonal,Tridiagonal,SymTridiagonal} -Broadcast.BroadcastStyle(::Type{<:StructuredMatrix}) = PromoteToSparse() -Broadcast.BroadcastStyle(::Type{<:Adjoint{T,<:Union{SparseVector,SparseMatrixCSC}} where T}) = PromoteToSparse() -Broadcast.BroadcastStyle(::Type{<:Transpose{T,<:Union{SparseVector,SparseMatrixCSC}} where T}) = PromoteToSparse() - -Broadcast.BroadcastStyle(s::SPVM, ::Broadcast.AbstractArrayStyle{0}) = s -Broadcast.BroadcastStyle(::SPVM, ::Broadcast.DefaultArrayStyle{1}) = PromoteToSparse() -Broadcast.BroadcastStyle(::SPVM, ::Broadcast.DefaultArrayStyle{2}) = PromoteToSparse() - -Broadcast.BroadcastStyle(::PromoteToSparse, ::SPVM) = PromoteToSparse() -Broadcast.BroadcastStyle(::PromoteToSparse, ::Broadcast.Style{Tuple}) = Broadcast.DefaultArrayStyle{2}() - -# FIXME: currently sparse broadcasts are only well-tested on known array types, while any AbstractArray -# could report itself as a DefaultArrayStyle(). -# See https://github.com/JuliaLang/julia/pull/23939#pullrequestreview-72075382 for more details -is_supported_sparse_broadcast() = true -is_supported_sparse_broadcast(::AbstractArray, rest...) = false -is_supported_sparse_broadcast(::AbstractSparseArray, rest...) = is_supported_sparse_broadcast(rest...) -is_supported_sparse_broadcast(::StructuredMatrix, rest...) = is_supported_sparse_broadcast(rest...) -is_supported_sparse_broadcast(::Array, rest...) = is_supported_sparse_broadcast(rest...) -is_supported_sparse_broadcast(t::Union{Transpose, Adjoint}, rest...) = is_supported_sparse_broadcast(t.parent, rest...) -is_supported_sparse_broadcast(x, rest...) = axes(x) === () && is_supported_sparse_broadcast(rest...) -is_supported_sparse_broadcast(x::Ref, rest...) = is_supported_sparse_broadcast(rest...) -function broadcast(f, s::PromoteToSparse, ::Nothing, ::Nothing, As::Vararg{Any,N}) where {N} - if is_supported_sparse_broadcast(As...) - return broadcast(f, map(_sparsifystructured, As)...) +function copy(bc::Broadcasted{PromoteToSparse}) + bcf = flatten(bc) + if is_supported_sparse_broadcast(bcf.args...) + broadcast(bcf.f, map(_sparsifystructured, bcf.args)...) else - return broadcast(f, Broadcast.ArrayConflict(), nothing, nothing, As...) + return copy(convert(Broadcasted{Broadcast.DefaultArrayStyle{2}}, bc)) end end -# For broadcast! with ::Any inputs, we need a layer of indirection to determine whether -# the inputs can be promoted to SparseVecOrMat. If it's just SparseVecOrMat and scalars, -# we can handle it here, otherwise see below for the promotion machinery. -function broadcast!(f::Tf, dest::SparseVecOrMat, ::SPVM, A::SparseVecOrMat, Bs::Vararg{SparseVecOrMat,N}) where {Tf,N} - if f isa typeof(identity) && N == 0 && Base.axes(dest) == Base.axes(A) - return copyto!(dest, A) - end - A′ = Base.unalias(dest, A) - Bs′ = map(B->Base.unalias(dest, B), Bs) - _aresameshape(dest, A′, Bs′...) && return _noshapecheck_map!(f, dest, A′, Bs′...) - Base.Broadcast.check_broadcast_indices(axes(dest), A′, Bs′...) - fofzeros = f(_zeros_eltypes(A′, Bs′...)...) - fpreszeros = _iszero(fofzeros) - fpreszeros ? _broadcast_zeropres!(f, dest, A′, Bs′...) : - _broadcast_notzeropres!(f, fofzeros, dest, A′, Bs′...) - return dest -end -function broadcast!(f::Tf, dest::SparseVecOrMat, ::SPVM, mixedsrcargs::Vararg{Any,N}) where {Tf,N} - # mixedsrcargs contains nothing but SparseVecOrMat and scalars - parevalf, passedsrcargstup = capturescalars(f, mixedsrcargs) - broadcast!(parevalf, dest, passedsrcargstup...) - return dest -end -function broadcast!(f::Tf, dest::SparseVecOrMat, ::PromoteToSparse, mixedsrcargs::Vararg{Any,N}) where {Tf,N} - broadcast!(f, dest, map(_sparsifystructured, mixedsrcargs)...) - return dest +@inline function copyto!(dest::SparseVecOrMat, bc::Broadcasted{PromoteToSparse}) + bcf = flatten(bc) + broadcast!(bcf.f, dest, map(_sparsifystructured, bcf.args)...) end _sparsifystructured(M::AbstractMatrix) = SparseMatrixCSC(M) @@ -1047,8 +1094,7 @@ _sparsifystructured(x) = x # (12) map[!] over combinations of sparse and structured matrices -SparseOrStructuredMatrix = Union{SparseMatrixCSC,StructuredMatrix} -map(f::Tf, A::StructuredMatrix) where {Tf} = _noshapecheck_map(f, _sparsifystructured(A)) +SparseOrStructuredMatrix = Union{SparseMatrixCSC,LinearAlgebra.StructuredMatrix} map(f::Tf, A::SparseOrStructuredMatrix, Bs::Vararg{SparseOrStructuredMatrix,N}) where {Tf,N} = (_checksameshape(A, Bs...); _noshapecheck_map(f, _sparsifystructured(A), map(_sparsifystructured, Bs)...)) map!(f::Tf, C::SparseMatrixCSC, A::SparseOrStructuredMatrix, Bs::Vararg{SparseOrStructuredMatrix,N}) where {Tf,N} = diff --git a/stdlib/SparseArrays/test/higherorderfns.jl b/stdlib/SparseArrays/test/higherorderfns.jl index b8e2c26d33349..8744f80a39dbe 100644 --- a/stdlib/SparseArrays/test/higherorderfns.jl +++ b/stdlib/SparseArrays/test/higherorderfns.jl @@ -125,9 +125,9 @@ end @test broadcast!(cos, Z, X) == sparse(broadcast!(cos, fZ, fX)) # --> test shape checks for broadcast! entry point # TODO strengthen this test, avoiding dependence on checking whether - # check_broadcast_indices throws to determine whether sparse broadcast should throw + # check_broadcast_axes throws to determine whether sparse broadcast should throw try - Base.Broadcast.check_broadcast_indices(axes(Z), spzeros((shapeX .- 1)...)) + Base.Broadcast.check_broadcast_axes(axes(Z), spzeros((shapeX .- 1)...)) catch @test_throws DimensionMismatch broadcast!(sin, Z, spzeros((shapeX .- 1)...)) end @@ -149,9 +149,9 @@ end @test broadcast!(cos, V, X) == sparse(broadcast!(cos, fV, fX)) # --> test shape checks for broadcast! entry point # TODO strengthen this test, avoiding dependence on checking whether - # check_broadcast_indices throws to determine whether sparse broadcast should throw + # check_broadcast_axes throws to determine whether sparse broadcast should throw try - Base.Broadcast.check_broadcast_indices(axes(V), spzeros((shapeX .- 1)...)) + Base.Broadcast.check_broadcast_axes(axes(V), spzeros((shapeX .- 1)...)) catch @test_throws DimensionMismatch broadcast!(sin, V, spzeros((shapeX .- 1)...)) end @@ -184,9 +184,9 @@ end @test broadcast(*, X, Y) == sparse(broadcast(*, fX, fY)) @test broadcast(f, X, Y) == sparse(broadcast(f, fX, fY)) # TODO strengthen this test, avoiding dependence on checking whether - # check_broadcast_indices throws to determine whether sparse broadcast should throw + # check_broadcast_axes throws to determine whether sparse broadcast should throw try - Base.Broadcast.combine_indices(spzeros((shapeX .- 1)...), Y) + Base.Broadcast.combine_axes(spzeros((shapeX .- 1)...), Y) catch @test_throws DimensionMismatch broadcast(+, spzeros((shapeX .- 1)...), Y) end @@ -207,9 +207,9 @@ end @test broadcast!(f, Z, X, Y) == sparse(broadcast!(f, fZ, fX, fY)) # --> test shape checks for both broadcast and broadcast! entry points # TODO strengthen this test, avoiding dependence on checking whether - # check_broadcast_indices throws to determine whether sparse broadcast should throw + # check_broadcast_axes throws to determine whether sparse broadcast should throw try - Base.Broadcast.check_broadcast_indices(axes(Z), spzeros((shapeX .- 1)...), Y) + Base.Broadcast.check_broadcast_axes(axes(Z), spzeros((shapeX .- 1)...), Y) catch @test_throws DimensionMismatch broadcast!(f, Z, spzeros((shapeX .- 1)...), Y) end @@ -247,9 +247,9 @@ end @test broadcast(*, X, Y, Z) == sparse(broadcast(*, fX, fY, fZ)) @test broadcast(f, X, Y, Z) == sparse(broadcast(f, fX, fY, fZ)) # TODO strengthen this test, avoiding dependence on checking whether - # check_broadcast_indices throws to determine whether sparse broadcast should throw + # check_broadcast_axes throws to determine whether sparse broadcast should throw try - Base.Broadcast.combine_indices(spzeros((shapeX .- 1)...), Y, Z) + Base.Broadcast.combine_axes(spzeros((shapeX .- 1)...), Y, Z) catch @test_throws DimensionMismatch broadcast(+, spzeros((shapeX .- 1)...), Y, Z) end @@ -267,6 +267,8 @@ end fQ = broadcast(f, fX, fY, fZ); Q = sparse(fQ) broadcast!(f, Q, X, Y, Z); Q = sparse(fQ) # warmup for @allocated @test_broken (@allocated broadcast!(f, Q, X, Y, Z)) == 0 + broadcast!(f, Q, X, Y, Z); Q = sparse(fQ) # warmup for @allocated + @test (@allocated broadcast!(f, Q, X, Y, Z)) <= 16 # the preceding test allocates 16 bytes in the entry point for broadcast!, but # none of the earlier tests of the same code path allocate. no allocation shows # up with --track-allocation=user. allocation shows up on the first line of the @@ -277,9 +279,9 @@ end @test broadcast!(f, Q, X, Y, Z) == sparse(broadcast!(f, fQ, fX, fY, fZ)) # --> test shape checks for both broadcast and broadcast! entry points # TODO strengthen this test, avoiding dependence on checking whether - # check_broadcast_indices throws to determine whether sparse broadcast should throw + # check_broadcast_axes throws to determine whether sparse broadcast should throw try - Base.Broadcast.check_broadcast_indices(axes(Q), spzeros((shapeX .- 1)...), Y, Z) + Base.Broadcast.check_broadcast_axes(axes(Q), spzeros((shapeX .- 1)...), Y, Z) catch @test_throws DimensionMismatch broadcast!(f, Q, spzeros((shapeX .- 1)...), Y, Z) end @@ -350,21 +352,11 @@ end @test broadcast!(*, X, sparseargs...) == sparse(broadcast!(*, fX, denseargs...)) @test isa(@inferred(broadcast!(*, X, sparseargs...)), SparseMatrixCSC{elT}) X = sparse(fX) # reset / warmup for @allocated test + # It'd be nice for this to be zero, but there's currently some constant overhead @test_broken (@allocated broadcast!(*, X, sparseargs...)) == 0 - # This test (and the analog below) fails for three reasons: - # (1) In all cases, generating the closures that capture the scalar arguments - # results in allocation, not sure why. - # (2) In some cases, though _broadcast_eltype (which wraps _return_type) - # consistently provides the correct result eltype when passed the closure - # that incorporates the scalar arguments to broadcast (and, with #19667, - # is inferable, so the overall return type from broadcast is inferred), - # in some cases inference seems unable to determine the return type of - # direct calls to that closure. This issue causes variables in both the - # broadcast[!] entry points (fofzeros = f(_zeros_eltypes(args...)...)) and - # the driver routines (Cx in _map_zeropres! and _broadcast_zeropres!) to have - # inferred type Any, resulting in allocation and lackluster performance. - # (3) The sparseargs... splat in the call above allocates a bit, but of course - # that issue is negligible and perhaps could be accounted for in the test. + X = sparse(fX) # reset / warmup for @allocated test + # And broadcasting over Transposes currently requires making a CSC copy, so we must account for that in the bounds + @test (@allocated broadcast!(*, X, sparseargs...)) <= (sum(x->isa(x, Transpose) ? Base.summarysize(x)*2+128 : 0, sparseargs) + 128) end end # test combinations at the limit of inference (eight arguments net) @@ -385,7 +377,8 @@ end @test isa(@inferred(broadcast!(*, X, sparseargs...)), SparseMatrixCSC{elT}) X = sparse(fX) # reset / warmup for @allocated test @test_broken (@allocated broadcast!(*, X, sparseargs...)) == 0 - # please see the note a few lines above re. this @test_broken + X = sparse(fX) # reset / warmup for @allocated test + @test (@allocated broadcast!(*, X, sparseargs...)) <= 128 end end @@ -404,20 +397,12 @@ end structuredarrays = (D, B, T, S) fstructuredarrays = map(Array, structuredarrays) for (X, fX) in zip(structuredarrays, fstructuredarrays) - @test (Q = broadcast(sin, X); Q isa SparseMatrixCSC && Q == sparse(broadcast(sin, fX))) - @test broadcast!(sin, Z, X) == sparse(broadcast(sin, fX)) - @test (Q = broadcast(cos, X); Q isa SparseMatrixCSC && Q == sparse(broadcast(cos, fX))) - @test broadcast!(cos, Z, X) == sparse(broadcast(cos, fX)) - @test (Q = broadcast(*, s, X); Q isa SparseMatrixCSC && Q == sparse(broadcast(*, s, fX))) - @test broadcast!(*, Z, s, X) == sparse(broadcast(*, s, fX)) @test (Q = broadcast(+, V, A, X); Q isa SparseMatrixCSC && Q == sparse(broadcast(+, fV, fA, fX))) @test broadcast!(+, Z, V, A, X) == sparse(broadcast(+, fV, fA, fX)) @test (Q = broadcast(*, s, V, A, X); Q isa SparseMatrixCSC && Q == sparse(broadcast(*, s, fV, fA, fX))) @test broadcast!(*, Z, s, V, A, X) == sparse(broadcast(*, s, fV, fA, fX)) for (Y, fY) in zip(structuredarrays, fstructuredarrays) - @test (Q = broadcast(+, X, Y); Q isa SparseMatrixCSC && Q == sparse(broadcast(+, fX, fY))) @test broadcast!(+, Z, X, Y) == sparse(broadcast(+, fX, fY)) - @test (Q = broadcast(*, X, Y); Q isa SparseMatrixCSC && Q == sparse(broadcast(*, fX, fY))) @test broadcast!(*, Z, X, Y) == sparse(broadcast(*, fX, fY)) end end @@ -426,9 +411,7 @@ end densearrays = (C, M) fD, fB = Array(D), Array(B) for X in densearrays - @test broadcast(+, D, X)::SparseMatrixCSC == sparse(broadcast(+, fD, X)) @test broadcast!(+, Z, D, X) == sparse(broadcast(+, fD, X)) - @test broadcast(*, s, B, X)::SparseMatrixCSC == sparse(broadcast(*, s, fB, X)) @test broadcast!(*, Z, s, B, X) == sparse(broadcast(*, s, fB, X)) @test broadcast(+, V, B, X)::SparseMatrixCSC == sparse(broadcast(+, fV, fB, X)) @test broadcast!(+, Z, V, B, X) == sparse(broadcast(+, fV, fB, X)) @@ -446,25 +429,6 @@ end @test A .+ ntuple(identity, N) isa Matrix end -@testset "broadcast! where the destination is a structured matrix" begin - # Where broadcast!'s destination is a structured matrix, broadcast! should fall back - # to the generic AbstractArray broadcast! code (at least for now). - N, p = 5, 0.4 - A = sprand(N, N, p) - sA = A + copy(A') - D = Diagonal(rand(N)) - B = Bidiagonal(rand(N), rand(N - 1), :U) - T = Tridiagonal(rand(N - 1), rand(N), rand(N - 1)) - @test broadcast!(sin, copy(D), D) == Diagonal(sin.(D)) - @test broadcast!(sin, copy(B), B) == Bidiagonal(sin.(B), :U) - @test broadcast!(sin, copy(T), T) == Tridiagonal(sin.(T)) - @test broadcast!(*, copy(D), D, A) == Diagonal(broadcast(*, D, A)) - @test broadcast!(*, copy(B), B, A) == Bidiagonal(broadcast(*, B, A), :U) - @test broadcast!(*, copy(T), T, A) == Tridiagonal(broadcast(*, T, A)) - # SymTridiagonal (and similar symmetric matrix types) do not support setindex! - # off the diagonal, and so cannot serve as a destination for broadcast! -end - @testset "map[!] over combinations of sparse and structured matrices" begin N, p = 10, 0.4 A = sprand(N, N, p) @@ -476,16 +440,12 @@ end structuredarrays = (D, B, T, S) fstructuredarrays = map(Array, structuredarrays) for (X, fX) in zip(structuredarrays, fstructuredarrays) - @test (Q = map(sin, X); Q isa SparseMatrixCSC && Q == sparse(map(sin, fX))) @test map!(sin, Z, X) == sparse(map(sin, fX)) - @test (Q = map(cos, X); Q isa SparseMatrixCSC && Q == sparse(map(cos, fX))) @test map!(cos, Z, X) == sparse(map(cos, fX)) @test (Q = map(+, A, X); Q isa SparseMatrixCSC && Q == sparse(map(+, fA, fX))) @test map!(+, Z, A, X) == sparse(map(+, fA, fX)) for (Y, fY) in zip(structuredarrays, fstructuredarrays) - @test (Q = map(+, X, Y); Q isa SparseMatrixCSC && Q == sparse(map(+, fX, fY))) @test map!(+, Z, X, Y) == sparse(map(+, fX, fY)) - @test (Q = map(*, X, Y); Q isa SparseMatrixCSC && Q == sparse(map(*, fX, fY))) @test map!(*, Z, X, Y) == sparse(map(*, fX, fY)) @test (Q = map(+, X, A, Y); Q isa SparseMatrixCSC && Q == sparse(map(+, fX, fA, fY))) @test map!(+, Z, X, A, Y) == sparse(map(+, fX, fA, fY)) diff --git a/test/bitarray.jl b/test/bitarray.jl index 572918a2dda77..806ecc6b4aa0f 100644 --- a/test/bitarray.jl +++ b/test/bitarray.jl @@ -1014,6 +1014,41 @@ timesofar("unary arithmetic") @check_bit_operation broadcast(^, b1, 1im) Matrix{ComplexF64} @check_bit_operation broadcast(^, b1, 0x1*im) Matrix{ComplexF64} end + + @testset "Matrix/Vector" begin + b1 = bitrand(n1, n2) + b2 = bitrand(n1) + b3 = bitrand(n2) + + @check_bit_operation broadcast(&, b1, b2) BitMatrix + @check_bit_operation broadcast(&, b1, transpose(b3)) BitMatrix + @check_bit_operation broadcast(&, b2, b1) BitMatrix + @check_bit_operation broadcast(&, transpose(b3), b1) BitMatrix + @check_bit_operation broadcast(|, b1, b2) BitMatrix + @check_bit_operation broadcast(|, b1, transpose(b3)) BitMatrix + @check_bit_operation broadcast(|, b2, b1) BitMatrix + @check_bit_operation broadcast(|, transpose(b3), b1) BitMatrix + @check_bit_operation broadcast(xor, b1, b2) BitMatrix + @check_bit_operation broadcast(xor, b1, transpose(b3)) BitMatrix + @check_bit_operation broadcast(xor, b2, b1) BitMatrix + @check_bit_operation broadcast(xor, transpose(b3), b1) BitMatrix + @check_bit_operation broadcast(+, b1, b2) Matrix{Int} + @check_bit_operation broadcast(+, b1, transpose(b3)) Matrix{Int} + @check_bit_operation broadcast(+, b2, b1) Matrix{Int} + @check_bit_operation broadcast(+, transpose(b3), b1) Matrix{Int} + @check_bit_operation broadcast(-, b1, b2) Matrix{Int} + @check_bit_operation broadcast(-, b1, transpose(b3)) Matrix{Int} + @check_bit_operation broadcast(-, b2, b1) Matrix{Int} + @check_bit_operation broadcast(-, transpose(b3), b1) Matrix{Int} + @check_bit_operation broadcast(*, b1, b2) BitMatrix + @check_bit_operation broadcast(*, b1, transpose(b3)) BitMatrix + @check_bit_operation broadcast(*, b2, b1) BitMatrix + @check_bit_operation broadcast(*, transpose(b3), b1) BitMatrix + @check_bit_operation broadcast(/, b1, b2) Matrix{Float64} + @check_bit_operation broadcast(/, b1, transpose(b3)) Matrix{Float64} + @check_bit_operation broadcast(/, b2, b1) Matrix{Float64} + @check_bit_operation broadcast(/, transpose(b3), b1) Matrix{Float64} + end end timesofar("binary arithmetic") diff --git a/test/broadcast.jl b/test/broadcast.jl index 966fc1a5a9e22..35f127658ad6f 100644 --- a/test/broadcast.jl +++ b/test/broadcast.jl @@ -2,7 +2,7 @@ module TestBroadcastInternals -using Base.Broadcast: check_broadcast_indices, check_broadcast_shape, newindex, _bcs +using Base.Broadcast: check_broadcast_axes, check_broadcast_shape, newindex, _bcs using Base: OneTo using Test, Random @@ -19,22 +19,22 @@ using Test, Random @test_throws DimensionMismatch _bcs((-1:1, 2:6), (-1:1, 2:5)) @test_throws DimensionMismatch _bcs((-1:1, 2:5), (2, 2:5)) -@test @inferred(Broadcast.combine_indices(zeros(3,4), zeros(3,4))) == (OneTo(3),OneTo(4)) -@test @inferred(Broadcast.combine_indices(zeros(3,4), zeros(3))) == (OneTo(3),OneTo(4)) -@test @inferred(Broadcast.combine_indices(zeros(3), zeros(3,4))) == (OneTo(3),OneTo(4)) -@test @inferred(Broadcast.combine_indices(zeros(3), zeros(1,4), zeros(1))) == (OneTo(3),OneTo(4)) - -check_broadcast_indices((OneTo(3),OneTo(5)), zeros(3,5)) -check_broadcast_indices((OneTo(3),OneTo(5)), zeros(3,1)) -check_broadcast_indices((OneTo(3),OneTo(5)), zeros(3)) -check_broadcast_indices((OneTo(3),OneTo(5)), zeros(3,5), zeros(3)) -check_broadcast_indices((OneTo(3),OneTo(5)), zeros(3,5), 1) -check_broadcast_indices((OneTo(3),OneTo(5)), 5, 2) -@test_throws DimensionMismatch check_broadcast_indices((OneTo(3),OneTo(5)), zeros(2,5)) -@test_throws DimensionMismatch check_broadcast_indices((OneTo(3),OneTo(5)), zeros(3,4)) -@test_throws DimensionMismatch check_broadcast_indices((OneTo(3),OneTo(5)), zeros(3,4,2)) -@test_throws DimensionMismatch check_broadcast_indices((OneTo(3),OneTo(5)), zeros(3,5), zeros(2)) -check_broadcast_indices((-1:1, 6:9), 1) +@test @inferred(Broadcast.combine_axes(zeros(3,4), zeros(3,4))) == (OneTo(3),OneTo(4)) +@test @inferred(Broadcast.combine_axes(zeros(3,4), zeros(3))) == (OneTo(3),OneTo(4)) +@test @inferred(Broadcast.combine_axes(zeros(3), zeros(3,4))) == (OneTo(3),OneTo(4)) +@test @inferred(Broadcast.combine_axes(zeros(3), zeros(1,4), zeros(1))) == (OneTo(3),OneTo(4)) + +check_broadcast_axes((OneTo(3),OneTo(5)), zeros(3,5)) +check_broadcast_axes((OneTo(3),OneTo(5)), zeros(3,1)) +check_broadcast_axes((OneTo(3),OneTo(5)), zeros(3)) +check_broadcast_axes((OneTo(3),OneTo(5)), zeros(3,5), zeros(3)) +check_broadcast_axes((OneTo(3),OneTo(5)), zeros(3,5), 1) +check_broadcast_axes((OneTo(3),OneTo(5)), 5, 2) +@test_throws DimensionMismatch check_broadcast_axes((OneTo(3),OneTo(5)), zeros(2,5)) +@test_throws DimensionMismatch check_broadcast_axes((OneTo(3),OneTo(5)), zeros(3,4)) +@test_throws DimensionMismatch check_broadcast_axes((OneTo(3),OneTo(5)), zeros(3,4,2)) +@test_throws DimensionMismatch check_broadcast_axes((OneTo(3),OneTo(5)), zeros(3,5), zeros(2)) +check_broadcast_axes((-1:1, 6:9), 1) check_broadcast_shape((-1:1, 6:9), (-1:1, 6:9)) check_broadcast_shape((-1:1, 6:9), (-1:1, 1)) @@ -167,7 +167,9 @@ rt = Base.return_types(broadcast!, Tuple{Function, Array{Float64, 3}, Array{Floa @test length(rt) == 1 && rt[1] == Array{Float64, 3} # f.(args...) syntax (#15032) -let x = [1,3.2,4.7], y = [3.5, pi, 1e-4], α = 0.2342 +let x = [1, 3.2, 4.7], + y = [3.5, pi, 1e-4], + α = 0.2342 @test sin.(x) == broadcast(sin, x) @test sin.(α) == broadcast(sin, α) @test sin.(3.2) == broadcast(sin, 3.2) == sin(3.2) @@ -237,12 +239,12 @@ let x = sin.(1:10), a = [x] @test atan2.(x, cos.(x)) == atan2.(a..., cos.(x)) == broadcast(atan2, x, cos.(a...)) == broadcast(atan2, a..., cos.(a...)) @test ((args...)->cos(args[1])).(x) == cos.(x) == ((y,args...)->cos(y)).(x) end -@test atan2.(3,4) == atan2(3,4) == (() -> atan2(3,4)).() +@test atan2.(3, 4) == atan2(3, 4) == (() -> atan2(3, 4)).() # fusion with keyword args: let x = [1:4;] f17300kw(x; y=0) = x + y @test f17300kw.(x) == x - @test f17300kw.(x, y=1) == f17300kw.(x; y=1) == f17300kw.(x; [(:y,1)]...) == x .+ 1 + @test f17300kw.(x, y=1) == f17300kw.(x; y=1) == f17300kw.(x; [(:y,1)]...) == x .+ 1 == [2, 3, 4, 5] @test f17300kw.(sin.(x), y=1) == f17300kw.(sin.(x); y=1) == sin.(x) .+ 1 @test sin.(f17300kw.(x, y=1)) == sin.(f17300kw.(x; y=1)) == sin.(x .+ 1) end @@ -408,7 +410,7 @@ StrangeType18623(x,y) = (x,y) let f(A, n) = broadcast(x -> +(x, n), A) @test @inferred(f([1.0], 1)) == [2.0] - g() = (a = 1; Broadcast.combine_eltypes(x -> x + a, 1.0)) + g() = (a = 1; Broadcast.combine_eltypes(x -> x + a, (1.0,))) @test @inferred(g()) === Float64 end @@ -428,7 +430,7 @@ abstract type ArrayData{T,N} <: AbstractArray{T,N} end Base.getindex(A::ArrayData, i::Integer...) = A.data[i...] Base.setindex!(A::ArrayData, v::Any, i::Integer...) = setindex!(A.data, v, i...) Base.size(A::ArrayData) = size(A.data) -Base.broadcast_similar(f, ::Broadcast.ArrayStyle{A}, ::Type{T}, inds::Tuple, As...) where {A,T} = +Base.broadcast_similar(::Broadcast.ArrayStyle{A}, ::Type{T}, inds::Tuple, bc) where {A,T} = A(Array{T}(undef, length.(inds))) struct Array19745{T,N} <: ArrayData{T,N} @@ -488,14 +490,21 @@ Base.BroadcastStyle(a2::Broadcast.ArrayStyle{AD2C}, a1::Broadcast.ArrayStyle{AD1 @testset "broadcasting for custom AbstractArray" begin a = randn(10) aa = Array19745(a) - @test a .+ 1 == @inferred(aa .+ 1) - @test a .* a' == @inferred(aa .* aa') + fadd(aa) = aa .+ 1 + fadd2(aa) = aa .+ 1 .* 2 + fprod(aa) = aa .* aa' + @test a .+ 1 == @inferred(fadd(aa)) + @test a .+ 1 .* 2 == @inferred(fadd2(aa)) + @test a .* a' == @inferred(fprod(aa)) @test isa(aa .+ 1, Array19745) + @test isa(aa .+ 1 .* 2, Array19745) @test isa(aa .* aa', Array19745) a1 = AD1(rand(2,3)) a2 = AD2(rand(2)) @test a1 .+ 1 isa AD1 @test a2 .+ 1 isa AD2 + @test a1 .+ 1 .* 2 isa AD1 + @test a2 .+ 1 .* 2 isa AD2 @test a1 .+ a2 isa Array @test a2 .+ a1 isa Array @test a1 .+ a2 .+ a1 isa Array @@ -504,6 +513,8 @@ Base.BroadcastStyle(a2::Broadcast.ArrayStyle{AD2C}, a1::Broadcast.ArrayStyle{AD1 a2 = AD2P(rand(2)) @test a1 .+ 1 isa AD1P @test a2 .+ 1 isa AD2P + @test a1 .+ 1 .* 2 isa AD1P + @test a2 .+ 1 .* 2 isa AD2P @test a1 .+ a2 isa AD1P @test a2 .+ a1 isa AD1P @test a1 .+ a2 .+ a1 isa AD1P @@ -512,6 +523,8 @@ Base.BroadcastStyle(a2::Broadcast.ArrayStyle{AD2C}, a1::Broadcast.ArrayStyle{AD1 a2 = AD2B(rand(2)) @test a1 .+ 1 isa AD1B @test a2 .+ 1 isa AD2B + @test a1 .+ 1 .* 2 isa AD1B + @test a2 .+ 1 .* 2 isa AD2B @test a1 .+ a2 isa AD1B @test a2 .+ a1 isa AD1B @test a1 .+ a2 .+ a1 isa AD1B @@ -520,6 +533,8 @@ Base.BroadcastStyle(a2::Broadcast.ArrayStyle{AD2C}, a1::Broadcast.ArrayStyle{AD1 a2 = AD2C(rand(2)) @test a1 .+ 1 isa AD1C @test a2 .+ 1 isa AD2C + @test a1 .+ 1 .* 2 isa AD1C + @test a2 .+ 1 .* 2 isa AD2C @test_throws ErrorException a1 .+ a2 end @@ -532,7 +547,7 @@ end # Test that broadcast's promotion mechanism handles closures accepting more than one argument. # (See issue #19641 and referenced issues and pull requests.) -let f() = (a = 1; Broadcast.combine_eltypes((x, y) -> x + y + a, 1.0, 1.0)) +let f() = (a = 1; Broadcast.combine_eltypes((x, y) -> x + y + a, (1.0, 1.0))) @test @inferred(f()) == Float64 end @@ -637,3 +652,52 @@ let n = 1 @test ceil.(Int, n ./ (1,)) == (1,) @test ceil.(Int, 1 ./ (1,)) == (1,) end + + +# lots of splatting! +let x = [[1, 4], [2, 5], [3, 6]] + y = .+(x..., .*(x..., x...)..., x[1]..., x[2]..., x[3]...) + @test y == [14463, 14472] + + z = zeros(2) + z .= .+(x..., .*(x..., x...)..., x[1]..., x[2]..., x[3]...) + @test z == Float64[14463, 14472] +end + +# Issue #21094 +@generated function foo21094(out, x) + quote + out .= x .+ x + out + end +end +@test foo21094([0.0], [1.0]) == [2.0] + +# Issue #22053 +struct T22053 + t +end +Broadcast.BroadcastStyle(::Type{T22053}) = Broadcast.Style{T22053}() +Broadcast.broadcast_axes(::T22053) = () +Broadcast.broadcastable(t::T22053) = t +function Base.copy(bc::Broadcast.Broadcasted{Broadcast.Style{T22053}}) + all(x->isa(x, T22053), bc.args) && return 1 + return 0 +end +Base.:*(::T22053, ::T22053) = 2 +let x = T22053(1) + @test x*x == 2 + @test x.*x == 1 +end + +# Issue https://github.com/JuliaLang/julia/pull/25377#discussion_r159956996 +let X = Any[1,2] + X .= nothing + @test X[1] == X[2] == nothing +end + +# Ensure that broadcast styles with custom indexing work +let X = zeros(2, 3) + X .= (1, 2) + @test X == [1 1 1; 2 2 2] +end diff --git a/test/core.jl b/test/core.jl index ba9f4b9d3a61a..668221ed7bf0b 100644 --- a/test/core.jl +++ b/test/core.jl @@ -1941,11 +1941,11 @@ test5884() # issue #5924 let - function Test() + function test5924() func = function () end func end - @test Test()() === nothing + @test test5924()() === nothing end # issue #6031 diff --git a/test/numbers.jl b/test/numbers.jl index b324aa7592728..b4a8910d365d5 100644 --- a/test/numbers.jl +++ b/test/numbers.jl @@ -2415,7 +2415,7 @@ Base.literal_pow(::typeof(^), ::PR20530, ::Val{p}) where {p} = 2 p = 2 @test x^p == 1 @test x^2 == 2 - @test [x,x,x].^2 == [2,2,2] + @test [x, x, x].^2 == [2, 2, 2] for T in (Float16, Float32, Float64, BigFloat, Int8, Int, BigInt, Complex{Int}, Complex{Float64}) for p in -4:4 v = eval(:($T(2)^$p)) @@ -2430,6 +2430,7 @@ Base.literal_pow(::typeof(^), ::PR20530, ::Val{p}) where {p} = 2 end @test PR20889(2)^3 == 5 @test [2,4,8].^-2 == [0.25, 0.0625, 0.015625] + @test [2, 4, 8].^-2 .* 4 == [1.0, 0.25, 0.0625] # nested literal_pow @test ℯ^-2 == exp(-2) ≈ inv(ℯ^2) ≈ (ℯ^-1)^2 ≈ sqrt(ℯ^-4) end module M20889 # do we get the expected behavior without importing Base.^? diff --git a/test/ranges.jl b/test/ranges.jl index c0ebdeee04c6e..9b32428dea766 100644 --- a/test/ranges.jl +++ b/test/ranges.jl @@ -477,15 +477,15 @@ end @test sum(0:0.1:10) == 505. end @testset "broadcasted operations with scalars" begin - @test broadcast(-, 1:3, 2) == -1:1 - @test broadcast(-, 1:3, 0.25) == 1-0.25:3-0.25 - @test broadcast(+, 1:3, 2) == 3:5 - @test broadcast(+, 1:3, 0.25) == 1+0.25:3+0.25 - @test broadcast(+, 1:2:6, 1) == 2:2:6 - @test broadcast(+, 1:2:6, 0.3) == 1+0.3:2:5+0.3 - @test broadcast(-, 1:2:6, 1) == 0:2:4 - @test broadcast(-, 1:2:6, 0.3) == 1-0.3:2:5-0.3 - @test broadcast(-, 2, 1:3) == 1:-1:-1 + @test broadcast(-, 1:3, 2) === -1:1 + @test broadcast(-, 1:3, 0.25) === 1-0.25:3-0.25 + @test broadcast(+, 1:3, 2) === 3:5 + @test broadcast(+, 1:3, 0.25) === 1+0.25:3+0.25 + @test broadcast(+, 1:2:6, 1) === 2:2:6 + @test broadcast(+, 1:2:6, 0.3) === 1+0.3:2:5+0.3 + @test broadcast(-, 1:2:6, 1) === 0:2:4 + @test broadcast(-, 1:2:6, 0.3) === 1-0.3:2:5-0.3 + @test broadcast(-, 2, 1:3) === 1:-1:-1 end @testset "operations between ranges and arrays" begin @test all(([1:5;] + (5:-1:1)) .== 6) @@ -551,27 +551,33 @@ end @test [0.0:prevfloat(0.1):0.3;] == [0.0, prevfloat(0.1), prevfloat(0.2), 0.3] @test [0.0:nextfloat(0.1):0.3;] == [0.0, nextfloat(0.1), nextfloat(0.2)] end -@testset "issue #7420 for type $T" for T = (Float32, Float64,), # BigFloat), - a = -5:25, - s = [-5:-1; 1:25; ], - d = 1:25, - n = -1:15 - - denom = convert(T, d) - strt = convert(T, a)/denom - Δ = convert(T, s)/denom - stop = convert(T, (a + (n - 1) * s)) / denom - vals = T[a:s:(a + (n - 1) * s); ] ./ denom - r = strt:Δ:stop - @test [r;] == vals - @test [range(strt, stop=stop, length=length(r));] == vals - n = length(r) - @test [r[1:n];] == [r;] - @test [r[2:n];] == [r;][2:end] - @test [r[1:3:n];] == [r;][1:3:n] - @test [r[2:2:n];] == [r;][2:2:n] - @test [r[n:-1:2];] == [r;][n:-1:2] - @test [r[n:-2:1];] == [r;][n:-2:1] + +function loop_range_values(::Type{T}) where T + for a = -5:25, + s = [-5:-1; 1:25; ], + d = 1:25, + n = -1:15 + + denom = convert(T, d) + strt = convert(T, a)/denom + Δ = convert(T, s)/denom + stop = convert(T, (a + (n - 1) * s)) / denom + vals = T[a:s:(a + (n - 1) * s); ] ./ denom + r = strt:Δ:stop + @test [r;] == vals + @test [range(strt, stop=stop, length=length(r));] == vals + n = length(r) + @test [r[1:n];] == [r;] + @test [r[2:n];] == [r;][2:end] + @test [r[1:3:n];] == [r;][1:3:n] + @test [r[2:2:n];] == [r;][2:2:n] + @test [r[n:-1:2];] == [r;][n:-1:2] + @test [r[n:-2:1];] == [r;][n:-2:1] + end +end + +@testset "issue #7420 for type $T" for T = (Float32, Float64,) # BigFloat), + loop_range_values(T) end @testset "issue #20373 (unliftable ranges with exact end points)" begin @@ -990,7 +996,10 @@ end for _r in (1:2:100, 1:100, 1f0:2f0:100f0, 1.0:2.0:100.0, range(1, stop=100, length=10), range(1f0, stop=100f0, length=10)) float_r = float(_r) - big_r = big.(_r) + big_r = broadcast(big, _r) + big_rdot = big.(_r) + @test big_rdot == big_r + @test typeof(big_r) == typeof(big_rdot) @test typeof(big_r).name === typeof(_r).name if eltype(_r) <: AbstractFloat @test isa(float_r, typeof(_r)) @@ -1217,6 +1226,22 @@ end @test map(BigFloat, x) === x end +@testset "broadcasting returns ranges" begin + x, r = 2, 1:5 + @test @inferred(x .+ r) === 3:7 + @test @inferred(r .+ x) === 3:7 + @test @inferred(r .- x) === -1:3 + @test @inferred(x .- r) === 1:-1:-3 + @test @inferred(x .* r) === 2:2:10 + @test @inferred(r .* x) === 2:2:10 + @test @inferred(r ./ x) === 0.5:0.5:2.5 + @test @inferred(x ./ r) == 2 ./ [r;] && isa(x ./ r, Vector{Float64}) + @test @inferred(r .\ x) == 2 ./ [r;] && isa(x ./ r, Vector{Float64}) + @test @inferred(x .\ r) === 0.5:0.5:2.5 + + @test @inferred(2 .* (r .+ 1) .+ 2) === 6:2:14 +end + @testset "Bad range calls" begin @test_throws ArgumentError range(1) @test_throws ArgumentError range(nothing) From c9a0a418e7876534b97c491bef7992b5dff4e7bb Mon Sep 17 00:00:00 2001 From: Matt Bauman Date: Tue, 24 Apr 2018 12:14:35 -0400 Subject: [PATCH 2/7] Rename Broadcast.make to Broadcast.broadcasted --- base/broadcast.jl | 96 +++++++++++----------- doc/src/manual/interfaces.md | 16 ++-- src/julia-syntax.scm | 2 +- stdlib/LinearAlgebra/src/uniformscaling.jl | 6 +- 4 files changed, 60 insertions(+), 60 deletions(-) diff --git a/base/broadcast.jl b/base/broadcast.jl index ef81b60f89f4b..3a8be92fbf0ad 100644 --- a/base/broadcast.jl +++ b/base/broadcast.jl @@ -689,7 +689,7 @@ julia> string.(("one","two","three","four"), ": ", 1:4) ``` """ -broadcast(f::Tf, As...) where {Tf} = copy(instantiate(make(f, As...))) +broadcast(f::Tf, As...) where {Tf} = copy(instantiate(broadcasted(f, As...))) # special cases defined for performance @inline broadcast(f, x::Number...) = f(x...) @@ -704,7 +704,7 @@ Note that `dest` is only used to store the result, and does not supply arguments to `f` unless it is also listed in the `As`, as in `broadcast!(f, A, A, B)` to perform `A[:] = broadcast(f, A, B)`. """ -broadcast!(f::Tf, dest, As::Vararg{Any,N}) where {Tf,N} = (materialize!(dest, make(f, As...)); dest) +broadcast!(f::Tf, dest, As::Vararg{Any,N}) where {Tf,N} = (materialize!(dest, broadcasted(f, As...)); dest) """ Broadcast.materialize(bc) @@ -926,57 +926,57 @@ _longest_tuple(A::NTuple{N,Any}, B::NTuple{N,Any}) where N = A ## scalar-range broadcast operations ## # DefaultArrayStyle and \ are not available at the time of range.jl -make(::DefaultArrayStyle{1}, ::typeof(-), r::OrdinalRange) = range(-first(r), step=-step(r), length=length(r)) -make(::DefaultArrayStyle{1}, ::typeof(-), r::StepRangeLen) = StepRangeLen(-r.ref, -r.step, length(r), r.offset) -make(::DefaultArrayStyle{1}, ::typeof(-), r::LinRange) = LinRange(-r.start, -r.stop, length(r)) +broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::OrdinalRange) = range(-first(r), step=-step(r), length=length(r)) +broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::StepRangeLen) = StepRangeLen(-r.ref, -r.step, length(r), r.offset) +broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::LinRange) = LinRange(-r.start, -r.stop, length(r)) -make(::DefaultArrayStyle{1}, ::typeof(+), x::Real, r::AbstractUnitRange) = range(x + first(r), length=length(r)) -make(::DefaultArrayStyle{1}, ::typeof(+), r::AbstractUnitRange, x::Real) = range(first(r) + x, length=length(r)) +broadcasted(::DefaultArrayStyle{1}, ::typeof(+), x::Real, r::AbstractUnitRange) = range(x + first(r), length=length(r)) +broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r::AbstractUnitRange, x::Real) = range(first(r) + x, length=length(r)) # For #18336 we need to prevent promotion of the step type: -make(::DefaultArrayStyle{1}, ::typeof(+), r::AbstractRange, x::Number) = range(first(r) + x, step=step(r), length=length(r)) -make(::DefaultArrayStyle{1}, ::typeof(+), x::Number, r::AbstractRange) = range(x + first(r), step=step(r), length=length(r)) -make(::DefaultArrayStyle{1}, ::typeof(+), r::StepRangeLen{T}, x::Number) where T = +broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r::AbstractRange, x::Number) = range(first(r) + x, step=step(r), length=length(r)) +broadcasted(::DefaultArrayStyle{1}, ::typeof(+), x::Number, r::AbstractRange) = range(x + first(r), step=step(r), length=length(r)) +broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r::StepRangeLen{T}, x::Number) where T = StepRangeLen{typeof(T(r.ref)+x)}(r.ref + x, r.step, length(r), r.offset) -make(::DefaultArrayStyle{1}, ::typeof(+), x::Number, r::StepRangeLen{T}) where T = +broadcasted(::DefaultArrayStyle{1}, ::typeof(+), x::Number, r::StepRangeLen{T}) where T = StepRangeLen{typeof(x+T(r.ref))}(x + r.ref, r.step, length(r), r.offset) -make(::DefaultArrayStyle{1}, ::typeof(+), r::LinRange, x::Number) = LinRange(r.start + x, r.stop + x, length(r)) -make(::DefaultArrayStyle{1}, ::typeof(+), x::Number, r::LinRange) = LinRange(x + r.start, x + r.stop, length(r)) -make(::DefaultArrayStyle{1}, ::typeof(+), r1::AbstractRange, r2::AbstractRange) = r1 + r2 - -make(::DefaultArrayStyle{1}, ::typeof(-), r::AbstractUnitRange, x::Number) = range(first(r)-x, length=length(r)) -make(::DefaultArrayStyle{1}, ::typeof(-), r::AbstractRange, x::Number) = range(first(r)-x, step=step(r), length=length(r)) -make(::DefaultArrayStyle{1}, ::typeof(-), x::Number, r::AbstractRange) = range(x-first(r), step=-step(r), length=length(r)) -make(::DefaultArrayStyle{1}, ::typeof(-), r::StepRangeLen{T}, x::Number) where T = +broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r::LinRange, x::Number) = LinRange(r.start + x, r.stop + x, length(r)) +broadcasted(::DefaultArrayStyle{1}, ::typeof(+), x::Number, r::LinRange) = LinRange(x + r.start, x + r.stop, length(r)) +broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r1::AbstractRange, r2::AbstractRange) = r1 + r2 + +broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::AbstractUnitRange, x::Number) = range(first(r)-x, length=length(r)) +broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::AbstractRange, x::Number) = range(first(r)-x, step=step(r), length=length(r)) +broadcasted(::DefaultArrayStyle{1}, ::typeof(-), x::Number, r::AbstractRange) = range(x-first(r), step=-step(r), length=length(r)) +broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::StepRangeLen{T}, x::Number) where T = StepRangeLen{typeof(T(r.ref)-x)}(r.ref - x, r.step, length(r), r.offset) -make(::DefaultArrayStyle{1}, ::typeof(-), x::Number, r::StepRangeLen{T}) where T = +broadcasted(::DefaultArrayStyle{1}, ::typeof(-), x::Number, r::StepRangeLen{T}) where T = StepRangeLen{typeof(x-T(r.ref))}(x - r.ref, -r.step, length(r), r.offset) -make(::DefaultArrayStyle{1}, ::typeof(-), r::LinRange, x::Number) = LinRange(r.start - x, r.stop - x, length(r)) -make(::DefaultArrayStyle{1}, ::typeof(-), x::Number, r::LinRange) = LinRange(x - r.start, x - r.stop, length(r)) -make(::DefaultArrayStyle{1}, ::typeof(-), r1::AbstractRange, r2::AbstractRange) = r1 - r2 +broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::LinRange, x::Number) = LinRange(r.start - x, r.stop - x, length(r)) +broadcasted(::DefaultArrayStyle{1}, ::typeof(-), x::Number, r::LinRange) = LinRange(x - r.start, x - r.stop, length(r)) +broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r1::AbstractRange, r2::AbstractRange) = r1 - r2 -make(::DefaultArrayStyle{1}, ::typeof(*), x::Number, r::AbstractRange) = range(x*first(r), step=x*step(r), length=length(r)) -make(::DefaultArrayStyle{1}, ::typeof(*), x::Number, r::StepRangeLen{T}) where {T} = +broadcasted(::DefaultArrayStyle{1}, ::typeof(*), x::Number, r::AbstractRange) = range(x*first(r), step=x*step(r), length=length(r)) +broadcasted(::DefaultArrayStyle{1}, ::typeof(*), x::Number, r::StepRangeLen{T}) where {T} = StepRangeLen{typeof(x*T(r.ref))}(x*r.ref, x*r.step, length(r), r.offset) -make(::DefaultArrayStyle{1}, ::typeof(*), x::Number, r::LinRange) = LinRange(x * r.start, x * r.stop, r.len) +broadcasted(::DefaultArrayStyle{1}, ::typeof(*), x::Number, r::LinRange) = LinRange(x * r.start, x * r.stop, r.len) # separate in case of noncommutative multiplication -make(::DefaultArrayStyle{1}, ::typeof(*), r::AbstractRange, x::Number) = range(first(r)*x, step=step(r)*x, length=length(r)) -make(::DefaultArrayStyle{1}, ::typeof(*), r::StepRangeLen{T}, x::Number) where {T} = +broadcasted(::DefaultArrayStyle{1}, ::typeof(*), r::AbstractRange, x::Number) = range(first(r)*x, step=step(r)*x, length=length(r)) +broadcasted(::DefaultArrayStyle{1}, ::typeof(*), r::StepRangeLen{T}, x::Number) where {T} = StepRangeLen{typeof(T(r.ref)*x)}(r.ref*x, r.step*x, length(r), r.offset) -make(::DefaultArrayStyle{1}, ::typeof(*), r::LinRange, x::Number) = LinRange(r.start * x, r.stop * x, r.len) +broadcasted(::DefaultArrayStyle{1}, ::typeof(*), r::LinRange, x::Number) = LinRange(r.start * x, r.stop * x, r.len) -make(::DefaultArrayStyle{1}, ::typeof(/), r::AbstractRange, x::Number) = range(first(r)/x, step=step(r)/x, length=length(r)) -make(::DefaultArrayStyle{1}, ::typeof(/), r::StepRangeLen{T}, x::Number) where {T} = +broadcasted(::DefaultArrayStyle{1}, ::typeof(/), r::AbstractRange, x::Number) = range(first(r)/x, step=step(r)/x, length=length(r)) +broadcasted(::DefaultArrayStyle{1}, ::typeof(/), r::StepRangeLen{T}, x::Number) where {T} = StepRangeLen{typeof(T(r.ref)/x)}(r.ref/x, r.step/x, length(r), r.offset) -make(::DefaultArrayStyle{1}, ::typeof(/), r::LinRange, x::Number) = LinRange(r.start / x, r.stop / x, r.len) +broadcasted(::DefaultArrayStyle{1}, ::typeof(/), r::LinRange, x::Number) = LinRange(r.start / x, r.stop / x, r.len) -make(::DefaultArrayStyle{1}, ::typeof(\), x::Number, r::AbstractRange) = range(x\first(r), step=x\step(r), length=length(r)) -make(::DefaultArrayStyle{1}, ::typeof(\), x::Number, r::StepRangeLen) = StepRangeLen(x\r.ref, x\r.step, length(r), r.offset) -make(::DefaultArrayStyle{1}, ::typeof(\), x::Number, r::LinRange) = LinRange(x \ r.start, x \ r.stop, r.len) +broadcasted(::DefaultArrayStyle{1}, ::typeof(\), x::Number, r::AbstractRange) = range(x\first(r), step=x\step(r), length=length(r)) +broadcasted(::DefaultArrayStyle{1}, ::typeof(\), x::Number, r::StepRangeLen) = StepRangeLen(x\r.ref, x\r.step, length(r), r.offset) +broadcasted(::DefaultArrayStyle{1}, ::typeof(\), x::Number, r::LinRange) = LinRange(x \ r.start, x \ r.stop, r.len) -make(::DefaultArrayStyle{1}, ::typeof(big), r::UnitRange) = big(r.start):big(last(r)) -make(::DefaultArrayStyle{1}, ::typeof(big), r::StepRange) = big(r.start):big(r.step):big(last(r)) -make(::DefaultArrayStyle{1}, ::typeof(big), r::StepRangeLen) = StepRangeLen(big(r.ref), big(r.step), length(r), r.offset) -make(::DefaultArrayStyle{1}, ::typeof(big), r::LinRange) = LinRange(big(r.start), big(r.stop), length(r)) +broadcasted(::DefaultArrayStyle{1}, ::typeof(big), r::UnitRange) = big(r.start):big(last(r)) +broadcasted(::DefaultArrayStyle{1}, ::typeof(big), r::StepRange) = big(r.start):big(r.step):big(last(r)) +broadcasted(::DefaultArrayStyle{1}, ::typeof(big), r::StepRangeLen) = StepRangeLen(big(r.ref), big(r.step), length(r), r.offset) +broadcasted(::DefaultArrayStyle{1}, ::typeof(big), r::LinRange) = LinRange(big(r.start), big(r.stop), length(r)) """ @@ -1186,26 +1186,26 @@ macro __dot__(x) esc(__dot__(x)) end -@inline make_kwsyntax(f, args...; kwargs...) = make((args...)->f(args...; kwargs...), args...) -@inline function make(f, args...) +@inline broadcasted_kwsyntax(f, args...; kwargs...) = broadcasted((args...)->f(args...; kwargs...), args...) +@inline function broadcasted(f, args...) args′ = map(broadcastable, args) - make(combine_styles(args′...), f, args′...) + broadcasted(combine_styles(args′...), f, args′...) end # Due to the current Type{T}/DataType specialization heuristics within Tuples, -# the totally generic varargs make(f, args...) method above loses Type{T}s in +# the totally generic varargs broadcasted(f, args...) method above loses Type{T}s in # mapping broadcastable across the args. These additional methods with explicit # arguments ensure we preserve Type{T}s in the first or second argument position. -@inline function make(f, arg1, args...) +@inline function broadcasted(f, arg1, args...) arg1′ = broadcastable(arg1) args′ = map(broadcastable, args) - make(combine_styles(arg1′, args′...), f, arg1′, args′...) + broadcasted(combine_styles(arg1′, args′...), f, arg1′, args′...) end -@inline function make(f, arg1, arg2, args...) +@inline function broadcasted(f, arg1, arg2, args...) arg1′ = broadcastable(arg1) arg2′ = broadcastable(arg2) args′ = map(broadcastable, args) - make(combine_styles(arg1′, arg2′, args′...), f, arg1′, arg2′, args′...) + broadcasted(combine_styles(arg1′, arg2′, args′...), f, arg1′, arg2′, args′...) end -@inline make(::S, f, args...) where S<:BroadcastStyle = Broadcasted{S}(f, args) +@inline broadcasted(::S, f, args...) where S<:BroadcastStyle = Broadcasted{S}(f, args) end # module diff --git a/doc/src/manual/interfaces.md b/doc/src/manual/interfaces.md index e237818334ed7..620bb4f5b9063 100644 --- a/doc/src/manual/interfaces.md +++ b/doc/src/manual/interfaces.md @@ -449,8 +449,8 @@ V = view(A, [1,2,4], :) # is not strided, as the spacing between rows is not f | `Base.copy(bc::Broadcasted{DestStyle})` | Custom implementation of `broadcast` | | `Base.copyto!(dest, bc::Broadcasted{DestStyle})` | Custom implementation of `broadcast!`, specializing on `DestStyle` | | `Base.copyto!(dest::DestType, bc::Broadcasted{Nothing})` | Custom implementation of `broadcast!`, specializing on `DestType` | -| `Base.Broadcast.make(f, args...)` | Override the default lazy behavior within a fused expression | -| `Base.Broadcast.instantiate(bc::Broadcasted{DestStyle})` | Override the computation of the wrapper's axes and indexers | +| `Base.Broadcast.broadcasted(f, args...)` | Override the default lazy behavior within a fused expression | +| `Base.Broadcast.instantiate(bc::Broadcasted{DestStyle})` | Override the computation of the lazy broadcast's axes | [Broadcasting](@ref) is triggered by an explicit call to `broadcast` or `broadcast!`, or implicitly by "dot" operations like `A .+ b` or `f.(x, y)`. Any object that has [`axes`](@ref) and supports @@ -615,22 +615,22 @@ need or want to evaluate `x .* (x .+ 1)` as if it had been written `broadcast(*, x, broadcast(+, x, 1))`, where the inner operation is evaluated before tackling the outer operation. This sort of eager operation is directly supported by a bit of indirection; instead of directly constructing `Broadcasted` objects, Julia lowers the -fused expression `x .* (x .+ 1)` to `Broadcast.make(*, x, Broadcast.make(+, x, 1))`. Now, -by default, `make` just calls the `Broadcasted` constructor to create the lazy representation +fused expression `x .* (x .+ 1)` to `Broadcast.broadcasted(*, x, Broadcast.broadcasted(+, x, 1))`. Now, +by default, `broadcasted` just calls the `Broadcasted` constructor to create the lazy representation of the fused expression tree, but you can choose to override it for a particular combination of function and arguments. As an example, the builtin `AbstractRange` objects use this machinery to optimize pieces of broadcasted expressions that can be eagerly evaluated purely in terms of the start, step, and length (or stop) instead of computing every single element. Just like all the -other machinery, `make` also computes and exposes the combined broadcast style of its -arguments, so instead of specializing on `make(f, args...)`, you can specialize on -`make(::DestStyle, f, args...)` for any combination of style, function, and arguments. +other machinery, `broadcasted` also computes and exposes the combined broadcast style of its +arguments, so instead of specializing on `broadcasted(f, args...)`, you can specialize on +`broadcasted(::DestStyle, f, args...)` for any combination of style, function, and arguments. For example, the following definition supports the negation of ranges: ```julia -make(::DefaultArrayStyle{1}, ::typeof(-), r::OrdinalRange) = range(-first(r), step=-step(r), length=length(r)) +broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::OrdinalRange) = range(-first(r), step=-step(r), length=length(r)) ``` ### [Extending in-place broadcasting](@id extending-in-place-broadcast) diff --git a/src/julia-syntax.scm b/src/julia-syntax.scm index ea7956bb1f369..9b0356533d0d2 100644 --- a/src/julia-syntax.scm +++ b/src/julia-syntax.scm @@ -1691,7 +1691,7 @@ (kws (car kws+args)) (kws (if (null? kws) kws (list (cons 'parameters kws)))) (args (map dot-to-fuse (cdr kws+args))) - (make `(call (|.| (top Broadcast) ,(if (null? kws) ''make ''make_kwsyntax)) ,@kws ,f ,@args))) + (make `(call (|.| (top Broadcast) ,(if (null? kws) ''broadcasted ''broadcasted_kwsyntax)) ,@kws ,f ,@args))) (if top (cons 'fuse make) make))) (if (and (pair? e) (eq? (car e) '|.|)) (let ((f (cadr e)) (x (caddr e))) diff --git a/stdlib/LinearAlgebra/src/uniformscaling.jl b/stdlib/LinearAlgebra/src/uniformscaling.jl index 5c2ba8720111f..6c04450da0b07 100644 --- a/stdlib/LinearAlgebra/src/uniformscaling.jl +++ b/stdlib/LinearAlgebra/src/uniformscaling.jl @@ -208,10 +208,10 @@ end \(x::Number, J::UniformScaling) = UniformScaling(x\J.λ) -Broadcast.make(::typeof(*), x::Number,J::UniformScaling) = UniformScaling(x*J.λ) -Broadcast.make(::typeof(*), J::UniformScaling,x::Number) = UniformScaling(J.λ*x) +Broadcast.broadcasted(::typeof(*), x::Number,J::UniformScaling) = UniformScaling(x*J.λ) +Broadcast.broadcasted(::typeof(*), J::UniformScaling,x::Number) = UniformScaling(J.λ*x) -Broadcast.make(::typeof(/), J::UniformScaling,x::Number) = UniformScaling(J.λ/x) +Broadcast.broadcasted(::typeof(/), J::UniformScaling,x::Number) = UniformScaling(J.λ/x) ==(J1::UniformScaling,J2::UniformScaling) = (J1.λ == J2.λ) From 4cd10d3de8c20c387f71bee4f73a836afc26ff7a Mon Sep 17 00:00:00 2001 From: Matt Bauman Date: Tue, 24 Apr 2018 13:34:24 -0400 Subject: [PATCH 3/7] Test and fix #26127 --- test/broadcast.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/test/broadcast.jl b/test/broadcast.jl index 35f127658ad6f..779f6371c2fed 100644 --- a/test/broadcast.jl +++ b/test/broadcast.jl @@ -701,3 +701,10 @@ let X = zeros(2, 3) X .= (1, 2) @test X == [1 1 1; 2 2 2] end + +# Issue #26127: multiple splats in a fused dot-expression +let f(args...) = *(args...) + x, y, z = (1,2), 3, (4, 5) + @test f.(x..., y, z...) == broadcast(f, x..., y, z...) == 120 + @test f.(x..., f.(x..., y, z...), y, z...) == broadcast(f, x..., broadcast(f, x..., y, z...), y, z...) == 120*120 +end From 8a2d88a493ad622be2a819e9013f80f1c3fbded6 Mon Sep 17 00:00:00 2001 From: Matt Bauman Date: Tue, 24 Apr 2018 15:39:01 -0400 Subject: [PATCH 4/7] Remove hack that avoided materialize in fallback broadcast For reasons beyond my comprehensions, this previously failed to inline despite the at-inline. For reasons that are also beyond my comprehension, this hack is no longer necessary. --- base/broadcast.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/base/broadcast.jl b/base/broadcast.jl index 3a8be92fbf0ad..628190f41b5d2 100644 --- a/base/broadcast.jl +++ b/base/broadcast.jl @@ -689,7 +689,7 @@ julia> string.(("one","two","three","four"), ": ", 1:4) ``` """ -broadcast(f::Tf, As...) where {Tf} = copy(instantiate(broadcasted(f, As...))) +broadcast(f::Tf, As...) where {Tf} = materialize(broadcasted(f, As...)) # special cases defined for performance @inline broadcast(f, x::Number...) = f(x...) From 75798f949f826d3c90b336150fd40e4fec7c27c3 Mon Sep 17 00:00:00 2001 From: Matt Bauman Date: Wed, 25 Apr 2018 01:54:42 -0400 Subject: [PATCH 5/7] Remove the last of the base broadcast methods... by implementing a preprocessing step for chunked bitarray broadcast that converts boolean-only functions to ones that support chunk-wise operation. Also add a number of tests for this codepath --- base/bitarray.jl | 1 - base/broadcast.jl | 25 +++++++++++++++++-------- base/sysimg.jl | 7 ------- test/bitarray.jl | 22 ++++++++++++++++++++++ 4 files changed, 39 insertions(+), 16 deletions(-) diff --git a/base/bitarray.jl b/base/bitarray.jl index 898980b92d4ac..be4600713110b 100644 --- a/base/bitarray.jl +++ b/base/bitarray.jl @@ -1095,7 +1095,6 @@ function (-)(B::BitArray) end return A end -broadcast(::typeof(sign), B::BitArray) = copy(B) """ flipbits!(B::BitArray{N}) -> BitArray{N} diff --git a/base/broadcast.jl b/base/broadcast.jl index 628190f41b5d2..ffca52e64de07 100644 --- a/base/broadcast.jl +++ b/base/broadcast.jl @@ -5,8 +5,8 @@ module Broadcast using .Base.Cartesian using .Base: Indices, OneTo, linearindices, tail, to_shape, isoperator, promote_typejoin, _msk_end, unsafe_bitgetindex, bitcache_chunks, bitcache_size, dumpbitcache, unalias -import .Base: broadcast, broadcast!, copy, copyto! -export BroadcastStyle, broadcast_axes, broadcast_similar, broadcastable, +import .Base: copy, copyto! +export broadcast, broadcast!, BroadcastStyle, broadcast_axes, broadcast_similar, broadcastable, broadcast_getindex, broadcast_setindex!, dotview, @__dot__ ### Objects with customized broadcasting behavior should declare a BroadcastStyle @@ -813,7 +813,7 @@ end # Performance optimization: for BitArray outputs, we cache the result # in a "small" Vector{Bool}, and then copy in chunks into the output -function copyto!(dest::BitArray, bc::Broadcasted{Nothing}) +@inline function copyto!(dest::BitArray, bc::Broadcasted{Nothing}) axes(dest) == axes(bc) || throwdm(axes(dest), axes(bc)) ischunkedbroadcast(dest, bc) && return chunkedcopyto!(dest, bc) tmp = Vector{Bool}(undef, bitcache_size) @@ -839,12 +839,13 @@ end # For some BitArray operations, we can work at the level of chunks. The trivial # implementation just walks over the UInt64 chunks in a linear fashion. # This requires three things: -# 1. The function must be known to work at the level of chunks -# 2. The only arrays involved must be BitArrays or scalars +# 1. The function must be known to work at the level of chunks (or can be converted to do so) +# 2. The only arrays involved must be BitArrays or scalar Bools # 3. There must not be any broadcasting beyond scalar — all array sizes must match # We could eventually allow for all broadcasting and other array types, but that # requires very careful consideration of all the edge effects. -const ChunkableOp = Union{typeof(&), typeof(|), typeof(xor), typeof(~)} +const ChunkableOp = Union{typeof(&), typeof(|), typeof(xor), typeof(~), typeof(identity), + typeof(!), typeof(*), typeof(==)} # these are convertable to chunkable ops by liftfuncs const BroadcastedChunkableOp{Style<:Union{Nothing,BroadcastStyle}, Axes, F<:ChunkableOp, Args<:Tuple} = Broadcasted{Style,Axes,F,Args} ischunkedbroadcast(R, bc::BroadcastedChunkableOp) = ischunkedbroadcast(R, bc.args) ischunkedbroadcast(R, args) = false @@ -853,6 +854,14 @@ ischunkedbroadcast(R, args::Tuple{<:Bool,Vararg{Any}}) = ischunkedbroadcast(R, t ischunkedbroadcast(R, args::Tuple{<:BroadcastedChunkableOp,Vararg{Any}}) = ischunkedbroadcast(R, args[1]) && ischunkedbroadcast(R, tail(args)) ischunkedbroadcast(R, args::Tuple{}) = true +# Convert compatible functions to chunkable ones. They must also be green-lighted as ChunkableOps +liftfuncs(bc::Broadcasted{Style}) where {Style} = Broadcasted{Style}(bc.f, map(liftfuncs, bc.args), bc.axes) +liftfuncs(bc::Broadcasted{Style,<:Any,typeof(sign)}) where {Style} = Broadcasted{Style}(identity, map(liftfuncs, bc.args), bc.axes) +liftfuncs(bc::Broadcasted{Style,<:Any,typeof(!)}) where {Style} = Broadcasted{Style}(~, map(liftfuncs, bc.args), bc.axes) +liftfuncs(bc::Broadcasted{Style,<:Any,typeof(*)}) where {Style} = Broadcasted{Style}(&, map(liftfuncs, bc.args), bc.axes) +liftfuncs(bc::Broadcasted{Style,<:Any,typeof(==)}) where {Style} = Broadcasted{Style}((~)∘(xor), map(liftfuncs, bc.args), bc.axes) +liftfuncs(x) = x + liftchunks(::Tuple{}) = () liftchunks(args::Tuple{<:BitArray,Vararg{Any}}) = (args[1].chunks, liftchunks(tail(args))...) # Transform scalars to repeated scalars the size of a chunk @@ -860,9 +869,9 @@ liftchunks(args::Tuple{<:Bool,Vararg{Any}}) = (ifelse(args[1], typemax(UInt64), ithchunk(i) = () Base.@propagate_inbounds ithchunk(i, c::Vector{UInt64}, args...) = (c[i], ithchunk(i, args...)...) Base.@propagate_inbounds ithchunk(i, b::UInt64, args...) = (b, ithchunk(i, args...)...) -function chunkedcopyto!(dest::BitArray, bc::Broadcasted) +@inline function chunkedcopyto!(dest::BitArray, bc::Broadcasted) isempty(dest) && return dest - f = flatten(bc) + f = flatten(liftfuncs(bc)) args = liftchunks(f.args) dc = dest.chunks @simd for i in eachindex(dc) diff --git a/base/sysimg.jl b/base/sysimg.jl index f5ad1d0ce42ee..662f6c217da25 100644 --- a/base/sysimg.jl +++ b/base/sysimg.jl @@ -135,13 +135,6 @@ using .Checked # vararg Symbol constructor Symbol(x...) = Symbol(string(x...)) -# Define the broadcast function, which is mostly implemented in -# broadcast.jl, so that we can overload broadcast methods for -# specific array types etc. -# --Here, just define fallback routines for broadcasting with no arguments -broadcast(f) = f() -broadcast!(f, X::AbstractArray) = (@inbounds for I in eachindex(X); X[I] = f(); end; X) - # array structures include("indices.jl") include("array.jl") diff --git a/test/bitarray.jl b/test/bitarray.jl index 806ecc6b4aa0f..f88a5d8a02913 100644 --- a/test/bitarray.jl +++ b/test/bitarray.jl @@ -1535,3 +1535,25 @@ timesofar("I/O") b[:] = view(trues(10), [1,3,7]) @test b == trues(3) end + +@testset "chunked broadcast" begin + for (f,g,h) in ((&,|,!),(*,xor,identity),(|,xor,sign),(&,&,~),(|,|,!)) + fg = (A, B, C)->f.(A, g.(B, C)) + fgh = (A, B, C)->f.(A, g.(B, h.(C))) + for n in (1, 63, 64, 65, 127, 128, 129) + for ((A,B,C),T) in ((bitrand.((n,n,n)), BitVector), (bitrand.((n,n,n), 2), BitMatrix)) + @check_bit_operation broadcast(f, A) T + @check_bit_operation broadcast(g, A) T + @check_bit_operation broadcast(h, A) T + @check_bit_operation fg(A, B, C) T + @check_bit_operation fg(true, B, C) T + @check_bit_operation fg(A, false, C) T + @check_bit_operation fg(A, B, true) T + @check_bit_operation fgh(A, B, C) T + @check_bit_operation fgh(true, B, C) T + @check_bit_operation fgh(A, false, C) T + @check_bit_operation fgh(A, B, true) T + end + end + end +end From edd4dcf7210b5378182a8b3d8faa23417b55bc0a Mon Sep 17 00:00:00 2001 From: Matt Bauman Date: Wed, 25 Apr 2018 10:07:15 -0400 Subject: [PATCH 6/7] Correct NEWS.md --- NEWS.md | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/NEWS.md b/NEWS.md index 0da63c9973120..67fc1343e9cbc 100644 --- a/NEWS.md +++ b/NEWS.md @@ -414,11 +414,12 @@ This section lists changes that do not have deprecation warnings. now returns cartesian indices for multidimensional arrays (see below, [#25532]). * Broadcasting operations are no longer fused into a single operation by Julia's parser. - Instead, a lazy `Broadcasted` wrapper is created, and the parser will call - `copy(bc::Broadcasted)` or `copyto!(dest, bc::Broadcasted)` + Instead, a lazy `Broadcasted` object is created to represent the fused expression and + then realized with `copy(bc::Broadcasted)` or `copyto!(dest, bc::Broadcasted)` to evaluate the wrapper. Consequently, package authors generally need to specialize - `copy` and `copyto!` methods rather than `broadcast` and `broadcast!`. - See the [Interfaces chapter](https://docs.julialang.org/en/latest/manual/interfaces/#Interfaces-1) + `copy` and `copyto!` methods rather than `broadcast` and `broadcast!`. This also allows + for more customization and control of fused broadcasts. See the + [Interfaces chapter](https://docs.julialang.org/en/latest/manual/interfaces/#man-interfaces-broadcasting-1) for more information. * `find` has been renamed to `findall`. `findall`, `findfirst`, `findlast`, `findnext` From 18ad6a8801eefd7a95987253dcb7c95db5528f74 Mon Sep 17 00:00:00 2001 From: Matt Bauman Date: Wed, 25 Apr 2018 11:43:26 -0400 Subject: [PATCH 7/7] Change lowering of `A .= x` to match fused RHS cases Now instead of lowering to `broadcast!(identity,A,x)`, it lowers to `Broadcast.materialize!(A, Broadcast.broadcasted(identity, x))`, with all names resolved in the top module. --- src/julia-syntax.scm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/julia-syntax.scm b/src/julia-syntax.scm index 9b0356533d0d2..13248257b8646 100644 --- a/src/julia-syntax.scm +++ b/src/julia-syntax.scm @@ -1721,7 +1721,7 @@ ; expanded to something else (like a getfield) (if (null? lhs) (expand-forms e) - (expand-forms `(call (top broadcast!) (top identity) ,lhs-view ,e)))))) + (expand-forms `(call (|.| (top Broadcast) 'materialize!) ,lhs-view (call (|.| (top Broadcast) 'broadcasted) (top identity) ,e))))))) (define (expand-where body var)