Skip to content

Commit

Permalink
Assume mutability for canonical (#31)
Browse files Browse the repository at this point in the history
* Assume mutability for canonical

* Fix allocation test

* Debug for Julia v1.6

* Use 0.7
  • Loading branch information
blegat authored May 27, 2024
1 parent 8255fe5 commit 69f107e
Show file tree
Hide file tree
Showing 6 changed files with 19 additions and 19 deletions.
19 changes: 7 additions & 12 deletions src/coefficients.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,30 +48,25 @@ If `ac` can be brought to canonical form in-place one has to implement
otherwise `canonical(ac)` needs to be implemented.
"""
canonical(ac::AbstractCoefficients) = ac
MA.operate(::typeof(canonical), x) = canonical(x) # fallback?
function canonical end
function MA.promote_operation(::typeof(canonical), ::Type{C}) where {C}
return C

Check warning on line 53 in src/coefficients.jl

View check run for this annotation

Codecov / codecov/patch

src/coefficients.jl#L52-L53

Added lines #L52 - L53 were not covered by tests
end

# example implementation for vectors
function MA.mutability(
::Type{<:Union{<:SparseVector,<:Vector}},
::typeof(canonical),
::Vararg{Type},
)
return MA.IsMutable()
end
MA.operate!(::typeof(canonical), sv::SparseVector) = dropzeros!(sv)
MA.operate!(::typeof(canonical), v::Vector) = v

function Base.:(==)(ac1::AbstractCoefficients, ac2::AbstractCoefficients)
ac1 = MA.operate!!(canonical, ac1)
ac2 = MA.operate!!(canonical, ac2)
MA.operate!(canonical, ac1)
MA.operate!(canonical, ac2)
all(x -> ==(x...), zip(keys(ac1), keys(ac2))) || return false
all(x -> ==(x...), zip(values(ac1), values(ac2))) || return false
return true
end

function Base.hash(ac::AbstractCoefficients, h::UInt)
ac = MA.operate!!(canonical, ac)
MA.operate!(canonical, ac)
return foldl((h, i) -> hash(i, h), nonzero_pairs(ac); init = h)
end

Expand Down
5 changes: 3 additions & 2 deletions src/diracs_augmented.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ function Base.getindex(aδ::Augmented{K}, i::K) where {K}
return zero(w)

Check warning on line 18 in src/diracs_augmented.jl

View check run for this annotation

Codecov / codecov/patch

src/diracs_augmented.jl#L18

Added line #L18 was not covered by tests
end

canonical(::Augmented) =
MA.operate!(::typeof(canonical), ::Augmented) =

Base.keys(aδ::Augmented) = (k = keys(aδ.elt); (one(first(k)), first(k)))
function Base.values(aδ::Augmented)
Expand Down Expand Up @@ -129,5 +129,6 @@ function coeffs!(
SparseCoefficients((target[Augmented(x)],), (1,)),
)
end
return MA.operate!!(canonical, res)
MA.operate!(canonical, res)
return res
end
5 changes: 3 additions & 2 deletions src/mstructures.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,9 @@ function MA.operate_to!(res, ms::MultiplicativeStructure, v, w)
throw(ArgumentError("No alias allowed"))
end
MA.operate!(zero, res)
res = MA.operate!(UnsafeAddMul(ms), res, v, w)
return MA.operate!!(canonical, res)
MA.operate!(UnsafeAddMul(ms), res, v, w)
MA.operate!(canonical, res)
return res
end

function MA.operate!(
Expand Down
5 changes: 4 additions & 1 deletion src/sparse_coeffs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,10 @@ function MA.operate!(::typeof(canonical), res::SparseCoefficients)
return MA.operate!(canonical, res, comparable(key_type(res)))
end

function MA.operate!(::typeof(canonical), res::SparseCoefficients, cmp)
# `::C` is needed to force Julia specialize on the function type
# Otherwise, we get one allocation when we call `issorted`
# See https://docs.julialang.org/en/v1/manual/performance-tips/#Be-aware-of-when-Julia-avoids-specializing
function MA.operate!(::typeof(canonical), res::SparseCoefficients, cmp::C) where {C}
sorted = issorted(res.basis_elements; lt = cmp)
distinct = allunique(res.basis_elements)
if sorted && distinct && !any(iszero, res.values)
Expand Down
2 changes: 1 addition & 1 deletion test/caching_allocations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ end
@test Y * Y isa AlgebraElement
Y * Y
k2 = @allocated Y * Y
@test k2 / k1 < 0.5
@test k2 / k1 < 0.7
end

@test all(!iszero, SA.mstructure(fRG).table)
Expand Down
2 changes: 1 addition & 1 deletion test/test_example_acoeffs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ ACoeffs(v::AbstractVector) = ACoeffs{valtype(v)}(v)
## Basic API
Base.keys(ac::ACoeffs) = (k for (k, v) in pairs(ac.vals) if !iszero(v))
Base.values(ac::ACoeffs) = (v for v in ac.vals if !iszero(v))
SA.canonical(ac::ACoeffs) = ac
MA.operate!(::typeof(SA.canonical), ac::ACoeffs) = ac
function SA.star(b::SA.AbstractBasis, ac::ACoeffs)
return ACoeffs([ac.vals[star(b, k)] for k in keys(ac)])
end
Expand Down

0 comments on commit 69f107e

Please sign in to comment.