Skip to content

Commit

Permalink
Fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
blegat committed Jul 3, 2024
1 parent f87c542 commit 2d38e15
Showing 1 changed file with 45 additions and 4 deletions.
49 changes: 45 additions & 4 deletions src/sparse_coeffs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,54 @@ end
################
# Broadcasting #
################
# Inspired from `JuMP.Containers.SparseAxisArray`

struct BroadcastStyle{K} end
struct BroadcastStyle{K} <: Broadcast.BroadcastStyle end

Base.broadcastable(sc::SparseCoefficients) = sc
Base.Broadcast.BroadcastStyle(::Type{<:SparseCoefficients{K}}) where {K} = BroadcastStyle{K}()
Base.BroadcastStyle(::Type{<:SparseCoefficients{K}}) where {K} = BroadcastStyle{K}()
# Disallow mixing broadcasts.
function Base.BroadcastStyle(::BroadcastStyle, ::Base.BroadcastStyle)
return throw(
ArgumentError(
"Cannot broadcast `StarAlgebras.SparseCoefficients` with" *
" another array of different type",
),
)
end

# Allow broadcasting over scalars.
function Base.BroadcastStyle(
style::BroadcastStyle,
::Base.Broadcast.DefaultArrayStyle{0},
)
return style
end

# Used for broadcasting
Base.axes(sc::SparseCoefficients) = (sc.basis_elements,)
#Base.Broadcast.BroadcastStyle(::Type{<:SparseCoefficients{K}}) where {K} = SparseArrays.HigherOrderFns.SparseVecStyle()
#SparseArrays.HigherOrderFns.nonscalararg(::SparseCoefficients) = true

# `_get_arg` and `getindex` are inspired from `JuMP.Containers.SparseAxisArray`
_getindex(x::SparseCoefficients, index) = getindex(x, index)
_getindex(x::Any, ::Any) = x
_getindex(x::Ref, ::Any) = x[]

function _get_arg(args::Tuple, index)
return (_getindex(first(args), index), _get_arg(Base.tail(args), index)...)
end
_get_arg(::Tuple{}, _) = ()

function Base.getindex(bc::Broadcast.Broadcasted{<:BroadcastStyle}, index)
return bc.f(_get_arg(bc.args, index)...)
end

function Base.similar(bc::Broadcast.Broadcasted{<:BroadcastStyle}, ::Type{T}) where {T}
return similar(_first_sparse_coeffs(bc.args...), T)
end

_first_sparse_coeffs(c::SparseCoefficients, args...) = c
_first_sparse_coeffs(_, args...) = _first_sparse_coeffs(args...)

function Base.zero(sc::SparseCoefficients)
return SparseCoefficients(empty(keys(sc)), empty(values(sc)))
Expand All @@ -66,7 +107,7 @@ function similar_type(::Type{SparseCoefficients{K,V,Vk,Vv}}, ::Type{T}) where {K
end

function Base.similar(s::SparseCoefficients, ::Type{T} = valtype(s)) where {T}
return SparseCoefficients(_similar(s.basis_elements), _similar(s.values, T))
return SparseCoefficients(collect(s.basis_elements), _similar(s.values, T))
end

function MA.mutability(
Expand Down

0 comments on commit 2d38e15

Please sign in to comment.