From a1fea311be307ecbced1ebddf01f5c24ce8a4a3c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Wed, 5 Jun 2024 18:45:16 +0200 Subject: [PATCH] Fixes --- src/bases.jl | 4 ++-- src/mstructures.jl | 54 +++++++++++++++++++++++++++----------------- src/sparse_coeffs.jl | 6 +++++ 3 files changed, 41 insertions(+), 23 deletions(-) diff --git a/src/bases.jl b/src/bases.jl index afd5951..813b880 100644 --- a/src/bases.jl +++ b/src/bases.jl @@ -36,8 +36,8 @@ Implicit bases are not stored in memory and can be potentially infinite. """ abstract type ImplicitBasis{T,I} <: AbstractBasis{T,I} end -function zero_coeffs(::Type{S}, ::ImplicitBasis{T}) where {S,T} - return SparseCoefficients(T[], S[]) +function zero_coeffs(::Type{S}, ::ImplicitBasis{T,I}) where {S,T,I} + return SparseCoefficients(I[], S[]) end """ diff --git a/src/mstructures.jl b/src/mstructures.jl index 2cf3e73..d92331e 100644 --- a/src/mstructures.jl +++ b/src/mstructures.jl @@ -44,42 +44,54 @@ struct UnsafeAddMul{M<:Union{typeof(*),MultiplicativeStructure}} structure::M end -function MA.operate_to!(res, ms::MultiplicativeStructure, v, w) - if res === v || res === w +function MA.operate_to!(res, ms::MultiplicativeStructure, args::Vararg{Any,N}) where {N} + if any(Base.Fix1(===, res), args) throw(ArgumentError("No alias allowed")) end MA.operate!(zero, res) - MA.operate!(UnsafeAddMul(ms), res, v, w) + MA.operate!(UnsafeAddMul(ms), res, args...) MA.operate!(canonical, res) return res end -function MA.operate!( - ::UnsafeAddMul{typeof(*)}, - mc::SparseCoefficients, - val, - c::AbstractCoefficients, -) - append!(mc.basis_elements, keys(c)) - vals = values(c) - if vals isa AbstractVector - append!(mc.values, val .* vals) - else - append!(mc.values, val * collect(values(c))) +struct One end +Base.:*(::One, α) = α + +function operate_with_constant!(::UnsafeAddMul, res, α, c) + for (k, v) in nonzero_pairs(c) + unsafe_push!(res, k, α * v) end - return mc + return res end -function MA.operate!(ms::UnsafeAddMul, res, v, w) - for (kv, a) in nonzero_pairs(v) - for (kw, b) in nonzero_pairs(w) - c = ms.structure(kv, kw) - MA.operate!(UnsafeAddMul(*), res, a * b, c) +function operate_with_constant!(op::UnsafeAddMul, res, α, b, c, args::Vararg{Any, N}) where {N} + for (kb, vb) in nonzero_pairs(b) + for (kc, vc) in nonzero_pairs(c) + operate_with_constant!(op, res, α * vb * vc, op.structure(kb, kc), args...) end end return res end +_aggregate_constants(constant, non_constant) = (constant, non_constant) + +function _aggregate_constants(constant, non_constant, α, args::Vararg{Any,N}) where {N} + return _aggregate_constants(constant * α, non_constant, args...) +end + +function _aggregate_constants(constant, non_constant, c::AbstractCoefficients, args::Vararg{Any,N}) where {N} + return _aggregate_constants(constant, (non_constant..., c), args...) +end + +function MA.operate!( + op::UnsafeAddMul, + mc::AbstractCoefficients, + args::Vararg{Any,N}, +) where {N} + constant, non_constant = _aggregate_constants(One(), tuple(), args...) + return operate_with_constant!(op, mc, constant, non_constant...) +end + struct DiracMStructure{Op} <: MultiplicativeStructure op::Op end diff --git a/src/sparse_coeffs.jl b/src/sparse_coeffs.jl index 5f07251..1affb37 100644 --- a/src/sparse_coeffs.jl +++ b/src/sparse_coeffs.jl @@ -68,6 +68,12 @@ function MA.operate!(::typeof(canonical), res::SparseCoefficients) return MA.operate!(canonical, res, comparable(key_type(res))) end +function unsafe_push!(res::SparseCoefficients, key, value) + push!(res.basis_elements, key) + push!(res.values, value) + return res +end + # `::C` is needed to force Julia specialize on the function type # Otherwise, we get one allocation when we call `issorted` # See https://docs.julialang.org/en/v1/manual/performance-tips/#Be-aware-of-when-Julia-avoids-specializing