Skip to content

Commit

Permalink
Assume mutability for canonical
Browse files Browse the repository at this point in the history
  • Loading branch information
blegat committed May 27, 2024
1 parent 35539ee commit a802e6b
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 17 deletions.
19 changes: 7 additions & 12 deletions src/coefficients.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,30 +48,25 @@ If `ac` can be brought to canonical form in-place one has to implement
otherwise `canonical(ac)` needs to be implemented.
"""
canonical(ac::AbstractCoefficients) = ac
MA.operate(::typeof(canonical), x) = canonical(x) # fallback?
function canonical end
function MA.promote_operation(::typeof(canonical), ::Type{C}) where {C}
return C
end

# example implementation for vectors
function MA.mutability(
::Type{<:Union{<:SparseVector,<:Vector}},
::typeof(canonical),
::Vararg{Type},
)
return MA.IsMutable()
end
MA.operate!(::typeof(canonical), sv::SparseVector) = dropzeros!(sv)
MA.operate!(::typeof(canonical), v::Vector) = v

function Base.:(==)(ac1::AbstractCoefficients, ac2::AbstractCoefficients)
ac1 = MA.operate!!(canonical, ac1)
ac2 = MA.operate!!(canonical, ac2)
MA.operate!(canonical, ac1)
MA.operate!(canonical, ac2)
all(x -> ==(x...), zip(keys(ac1), keys(ac2))) || return false
all(x -> ==(x...), zip(values(ac1), values(ac2))) || return false
return true
end

function Base.hash(ac::AbstractCoefficients, h::UInt)
ac = MA.operate!!(canonical, ac)
MA.operate!(canonical, ac)
return foldl((h, i) -> hash(i, h), nonzero_pairs(ac); init = h)
end

Expand Down
5 changes: 3 additions & 2 deletions src/diracs_augmented.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ function Base.getindex(aδ::Augmented{K}, i::K) where {K}
return zero(w)
end

canonical(::Augmented) =
MA.operate!(::typeof(canonical), ::Augmented) =

Base.keys(aδ::Augmented) = (k = keys(aδ.elt); (one(first(k)), first(k)))
function Base.values(aδ::Augmented)
Expand Down Expand Up @@ -129,5 +129,6 @@ function coeffs!(
SparseCoefficients((target[Augmented(x)],), (1,)),
)
end
return MA.operate!!(canonical, res)
MA.operate!(canonical, res)
return res
end
5 changes: 3 additions & 2 deletions src/mstructures.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,9 @@ function MA.operate_to!(res, ms::MultiplicativeStructure, v, w)
throw(ArgumentError("No alias allowed"))
end
MA.operate!(zero, res)
res = MA.operate!(UnsafeAddMul(ms), res, v, w)
return MA.operate!!(canonical, res)
MA.operate!(UnsafeAddMul(ms), res, v, w)
MA.operate!(canonical, res)
return res
end

function MA.operate!(
Expand Down
2 changes: 1 addition & 1 deletion test/test_example_acoeffs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ ACoeffs(v::AbstractVector) = ACoeffs{valtype(v)}(v)
## Basic API
Base.keys(ac::ACoeffs) = (k for (k, v) in pairs(ac.vals) if !iszero(v))
Base.values(ac::ACoeffs) = (v for v in ac.vals if !iszero(v))
SA.canonical(ac::ACoeffs) = ac
MA.operate!(::typeof(SA.canonical), ac::ACoeffs) = ac
function SA.star(b::SA.AbstractBasis, ac::ACoeffs)
return ACoeffs([ac.vals[star(b, k)] for k in keys(ac)])
end
Expand Down

0 comments on commit a802e6b

Please sign in to comment.