Skip to content

Commit

Permalink
Fix value_type and key_type for Tuple (#59)
Browse files Browse the repository at this point in the history
* Fix value_type and key_type for Tuple

* Add tests

* Add comment

* Remove iszero fix
  • Loading branch information
blegat authored Aug 26, 2024
1 parent ba3697f commit bbcd649
Show file tree
Hide file tree
Showing 8 changed files with 39 additions and 19 deletions.
4 changes: 2 additions & 2 deletions src/bases.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ Translate coefficients `cfs` in `source::AbstractBasis` to basis
function coeffs(cfs, source::AbstractBasis, target::AbstractBasis)
source === target && return cfs
source == target && return cfs
res = zero_coeffs(valtype(cfs), target)
res = zero_coeffs(value_type(cfs), target)
return coeffs!(res, cfs, source, target)
end

Expand All @@ -86,7 +86,7 @@ Return `A' * cfs` where `A` is the linear map applied by
function adjoint_coeffs(cfs, source::AbstractBasis, target::AbstractBasis)
source === target && return cfs
source == target && return cfs
res = zero_coeffs(valtype(cfs), source)
res = zero_coeffs(value_type(cfs), source)
return adjoint_coeffs!(res, cfs, source, target)
end

Expand Down
22 changes: 14 additions & 8 deletions src/coefficients.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,27 @@ provided based on random indexing. Additionally one needs to provide:
abstract type AbstractCoefficients{K,V} end

key_type(::Type{<:AbstractCoefficients{K}}) where {K} = K
Base.valtype(::Type{<:AbstractCoefficients{K,V}}) where {K,V} = V
value_type(::Type{<:AbstractCoefficients{K,V}}) where {K,V} = V
key_type(b::AbstractCoefficients) = key_type(typeof(b))
Base.valtype(b::AbstractCoefficients) = valtype(typeof(b))
value_type(b::AbstractCoefficients) = value_type(typeof(b))

key_type(b) = keytype(b)
# `keytype(::Type{SparseVector{V,K}})` is not defined so it falls
# back to `keytype{::Type{<:AbstractArray})` which returns `Int`.
key_type(::Type{SparseArrays.SparseVector{V,K}}) where {V,K} = K
key_type(v::SparseArrays.SparseVector) = key_type(typeof(v))
key_type(::Tuple) = Int

value_type(coeffs) = valtype(coeffs)
# `valtype` is not defined for `Tuple` so we need to define
# our own function `value_type` as defining `valtype` for
# tuples would be type piracy
value_type(::Type{NTuple{N,T}}) where {N,T} = T
value_type(::NTuple{N,T}) where {N,T} = T

Base.iszero(ac::AbstractCoefficients) = isempty(keys(ac))

Base.similar(ac::AbstractCoefficients) = similar(ac, valtype(ac))
Base.similar(ac::AbstractCoefficients) = similar(ac, value_type(ac))

similar_type(::Type{<:Vector}, ::Type{T}) where {T} = Vector{T}
similar_type(::Type{<:SparseArrays.SparseVector{C,I}}, ::Type{T}) where {C,I,T} = SparseArrays.SparseVector{T,I}
Expand Down Expand Up @@ -96,7 +104,7 @@ end

function LinearAlgebra.dot(ac::AbstractCoefficients, bc::AbstractCoefficients)
if isempty(values(ac)) || isempty(values(bc))
return zero(MA.promote_sum_mul(valtype(ac), valtype(bc)))
return zero(MA.promote_sum_mul(value_type(ac), value_type(bc)))
else
return sum(c * star(bc[i]) for (i, c) in nonzero_pairs(ac))
end
Expand All @@ -105,7 +113,7 @@ end
function LinearAlgebra.dot(w::AbstractVector, ac::AbstractCoefficients)
@assert key_type(ac) <: Integer
if isempty(values(ac))
return zero(MA.promote_sum_mul(eltype(w), valtype(ac)))
return zero(MA.promote_sum_mul(eltype(w), value_type(ac)))
else
return sum(w[i] * star(v) for (i, v) in nonzero_pairs(ac))
end
Expand All @@ -114,7 +122,7 @@ end
function LinearAlgebra.dot(ac::AbstractCoefficients, w::AbstractVector)
@assert key_type(ac) <: Integer
if isempty(values(ac))
return zero(MA.promote_sum_mul(eltype(w), valtype(ac)))
return zero(MA.promote_sum_mul(eltype(w), value_type(ac)))
else
return sum(v * star(w[i]) for (i, v) in nonzero_pairs(ac))
end
Expand Down Expand Up @@ -260,5 +268,3 @@ function MA.operate_to!(
end
return res
end


2 changes: 1 addition & 1 deletion src/diracs_augmented.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ aug(cfs::Any) = sum(values(cfs))
aug(a::AlgebraElement) = aug(coeffs(a))

function aug(ac::AbstractCoefficients)
isempty(keys(ac)) && return zero(valtype(ac))
isempty(keys(ac)) && return zero(value_type(ac))
return sum(c * aug(x) for (x, c) in nonzero_pairs(ac))
end

Expand Down
2 changes: 1 addition & 1 deletion src/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ end
function _show(io::IO, mime, a::AlgebraElement)
A = parent(a)
if iszero(a)
T = valtype(coeffs(a))
T = value_type(coeffs(a))
_coeff_elt_print(io, mime, zero(T), first(basis(A)))
else
_first = true
Expand Down
8 changes: 4 additions & 4 deletions src/sparse_coeffs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ function Base.getindex(sc::SparseCoefficients{K}, key::K) where {K}
return zero(v)
end
else
return zero(valtype(sc))
return zero(value_type(sc))
end
end

Expand Down Expand Up @@ -112,7 +112,7 @@ function similar_type(::Type{SparseCoefficients{K,V,Vk,Vv}}, ::Type{T}) where {K
return SparseCoefficients{K,T,_similar_type(Vk, K),_similar_type(Vv, T)}
end

function Base.similar(s::SparseCoefficients, ::Type{T} = valtype(s)) where {T}
function Base.similar(s::SparseCoefficients, ::Type{T} = value_type(s)) where {T}
return SparseCoefficients(collect(s.basis_elements), _similar(s.values, T))
end

Expand All @@ -127,13 +127,13 @@ end

### temporary convenience? how to handle this?
function __prealloc(X::SparseCoefficients, a::Number, op)
T = MA.promote_operation(op, valtype(X), typeof(a))
T = MA.promote_operation(op, value_type(X), typeof(a))
return similar(X, T)
end

function __prealloc(X::SparseCoefficients, Y::SparseCoefficients, op)
# this is not even correct for op = *
T = MA.promote_operation(op, valtype(X), valtype(Y))
T = MA.promote_operation(op, value_type(X), value_type(Y))
return similar(X, T)
end

Expand Down
6 changes: 3 additions & 3 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ struct AlgebraElement{A,T,V} <: MA.AbstractMutable
end

Base.parent(a::AlgebraElement) = a.parent
Base.eltype(::Type{A}) where {A<:AlgebraElement} = valtype(MA.promote_operation(coeffs, A))
Base.eltype(::Type{A}) where {A<:AlgebraElement} = value_type(MA.promote_operation(coeffs, A))
Base.eltype(a::AlgebraElement) = eltype(typeof(a))
function MA.promote_operation(::typeof(coeffs), ::Type{AlgebraElement{A,T,V}}) where {A,T,V}
return V
Expand All @@ -53,14 +53,14 @@ basis(a::AlgebraElement) = basis(parent(a))

function AlgebraElement(coeffs, A::AbstractStarAlgebra)
_sanity_checks(coeffs, A)
return AlgebraElement{typeof(A),valtype(coeffs),typeof(coeffs)}(coeffs, A)
return AlgebraElement{typeof(A),value_type(coeffs),typeof(coeffs)}(coeffs, A)
end

function AlgebraElement(
coeffs::SparseCoefficients{T},
A::AbstractStarAlgebra{O,T},
) where {O,T}
return AlgebraElement{typeof(A),valtype(coeffs),typeof(coeffs)}(coeffs, A)
return AlgebraElement{typeof(A),value_type(coeffs),typeof(coeffs)}(coeffs, A)
end

### constructing elements
Expand Down
13 changes: 13 additions & 0 deletions test/basic.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
struct DummyBasis{T} <: SA.ExplicitBasis{T,Int}
elements::Vector{T}
end

Base.length(b::DummyBasis) = length(b.elements)
Base.getindex(b::DummyBasis, i::Int) = b.elements[i]

@testset "Basic tests" begin
b = DummyBasis(Irrational[π, ℯ])
a = StarAlgebra(nothing, b)
s(i) = sprint(show, MIME"text/plain"(), i)
@test sprint(show, AlgebraElement([2, -1], a)) == "$(s(π)) - 1·$(s(ℯ))"
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ include("test_example_words.jl")
include("test_example_acoeffs.jl")

@testset "StarAlgebras" begin
include("basic.jl")
# proof of concept
using PermutationGroups
include("perm_grp_algebra.jl")
Expand Down

0 comments on commit bbcd649

Please sign in to comment.