From 3b0a5dccd467021652aca5ffe58df86e9459c278 Mon Sep 17 00:00:00 2001 From: Pablo Zubieta Date: Wed, 8 Mar 2017 23:49:49 -0600 Subject: [PATCH] Make broadcast over sparse and rowvecs sparse --- base/sparse/higherorderfns.jl | 15 +++++++++------ test/sparse/higherorderfns.jl | 13 ++++++++----- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/base/sparse/higherorderfns.jl b/base/sparse/higherorderfns.jl index 6ef9e4270063d..abc148485a195 100644 --- a/base/sparse/higherorderfns.jl +++ b/base/sparse/higherorderfns.jl @@ -6,10 +6,11 @@ module HigherOrderFns # particularly map[!]/broadcast[!] for SparseVectors and SparseMatrixCSCs at present. import Base: map, map!, broadcast, broadcast! import Base.Broadcast: _containertype, promote_containertype, - broadcast_indices, broadcast_c, broadcast_c! + broadcast_indices, broadcast_c, broadcast_c! using Base: front, tail, to_shape -using ..SparseArrays: SparseVector, SparseMatrixCSC, AbstractSparseArray, indtype +using ..SparseArrays: SparseVector, SparseMatrixCSC, AbstractSparseVector, + AbstractSparseMatrix, AbstractSparseArray, indtype # This module is organized as follows: # (1) Define a common interface to SparseVectors and SparseMatrixCSCs sufficient for @@ -1063,6 +1064,7 @@ promote_containertype(::Type{Array}, ::Type{AbstractSparseArray}) = PromoteToSpa _spcontainertype(x) = _containertype(x) _spcontainertype(::Type{<:Vector}) = Vector _spcontainertype(::Type{<:Matrix}) = Matrix +_spcontainertype(::Type{<:RowVector}) = Matrix _spcontainertype(::Type{<:Ref}) = AbstractArray _spcontainertype(::Type{<:AbstractArray}) = AbstractArray # need the following two methods to override the immediately preceding method @@ -1105,10 +1107,11 @@ promote_spcontainertype(::FunnelToSparseBC, ::FunnelToSparseBC) = PromoteToSpars @inline spbroadcast_c!{N}(f, ::Type{AbstractSparseArray}, ::Type{AbstractArray}, C, B, As::Vararg{Any,N}) = broadcast_c!(f, Array, Array, C, B, As...) -@inline _sparsifystructured(A::AbstractSparseArray) = A -@inline _sparsifystructured(M::StructuredMatrix) = SparseMatrixCSC(M) -@inline _sparsifystructured(M::Matrix) = SparseMatrixCSC(M) -@inline _sparsifystructured(V::Vector) = SparseVector(V) +@inline _sparsifystructured(M::AbstractMatrix) = SparseMatrixCSC(M) +@inline _sparsifystructured(V::AbstractVector) = SparseVector(V) +@inline _sparsifystructured(M::AbstractSparseMatrix) = SparseMatrixCSC(M) +@inline _sparsifystructured(V::AbstractSparseVector) = SparseVector(V) +@inline _sparsifystructured(S::SparseVecOrMat) = S @inline _sparsifystructured(x) = x diff --git a/test/sparse/higherorderfns.jl b/test/sparse/higherorderfns.jl index a9cbd40de46e3..32b9de28c3dd2 100644 --- a/test/sparse/higherorderfns.jl +++ b/test/sparse/higherorderfns.jl @@ -403,16 +403,19 @@ end densearrays = (C, M) fD, fB = Array(D), Array(B) for X in densearrays - @test (Q = broadcast(+, D, X); Q isa SparseMatrixCSC && Q == sparse(broadcast(+, fD, X))) + @test broadcast(+, D, X)::SparseMatrixCSC == sparse(broadcast(+, fD, X)) @test broadcast!(+, Z, D, X) == sparse(broadcast(+, fD, X)) - @test (Q = broadcast(*, s, B, X); Q isa SparseMatrixCSC && Q == sparse(broadcast(*, s, fB, X))) + @test broadcast(*, s, B, X)::SparseMatrixCSC == sparse(broadcast(*, s, fB, X)) @test broadcast!(*, Z, s, B, X) == sparse(broadcast(*, s, fB, X)) - @test (Q = broadcast(+, V, B, X); Q isa SparseMatrixCSC && Q == sparse(broadcast(+, fV, fB, X))) + @test broadcast(+, V, B, X)::SparseMatrixCSC == sparse(broadcast(+, fV, fB, X)) @test broadcast!(+, Z, V, B, X) == sparse(broadcast(+, fV, fB, X)) - @test (Q = broadcast(+, V, A, X); Q isa SparseMatrixCSC && Q == sparse(broadcast(+, fV, fA, X))) + @test broadcast(+, V, A, X)::SparseMatrixCSC == sparse(broadcast(+, fV, fA, X)) @test broadcast!(+, Z, V, A, X) == sparse(broadcast(+, fV, fA, X)) - @test (Q = broadcast(*, s, V, A, X); Q isa SparseMatrixCSC && Q == sparse(broadcast(*, s, fV, fA, X))) + @test broadcast(*, s, V, A, X)::SparseMatrixCSC == sparse(broadcast(*, s, fV, fA, X)) @test broadcast!(*, Z, s, V, A, X) == sparse(broadcast(*, s, fV, fA, X)) + # Issue #20954 combinations of sparse arrays and RowVectors + @test broadcast(+, A, X')::SparseMatrixCSC == sparse(broadcast(+, fA, X')) + @test broadcast(*, V, X')::SparseMatrixCSC == sparse(broadcast(*, fV, X')) end end