Skip to content

Commit

Permalink
Fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
blegat committed Jun 5, 2024
1 parent b1acfd0 commit a1fea31
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 23 deletions.
4 changes: 2 additions & 2 deletions src/bases.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ Implicit bases are not stored in memory and can be potentially infinite.
"""
abstract type ImplicitBasis{T,I} <: AbstractBasis{T,I} end

function zero_coeffs(::Type{S}, ::ImplicitBasis{T}) where {S,T}
return SparseCoefficients(T[], S[])
function zero_coeffs(::Type{S}, ::ImplicitBasis{T,I}) where {S,T,I}
return SparseCoefficients(I[], S[])
end

"""
Expand Down
54 changes: 33 additions & 21 deletions src/mstructures.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,42 +44,54 @@ struct UnsafeAddMul{M<:Union{typeof(*),MultiplicativeStructure}}
structure::M
end

function MA.operate_to!(res, ms::MultiplicativeStructure, v, w)
if res === v || res === w
function MA.operate_to!(res, ms::MultiplicativeStructure, args::Vararg{Any,N}) where {N}
if any(Base.Fix1(===, res), args)
throw(ArgumentError("No alias allowed"))
end
MA.operate!(zero, res)
MA.operate!(UnsafeAddMul(ms), res, v, w)
MA.operate!(UnsafeAddMul(ms), res, args...)
MA.operate!(canonical, res)
return res
end

function MA.operate!(
::UnsafeAddMul{typeof(*)},
mc::SparseCoefficients,
val,
c::AbstractCoefficients,
)
append!(mc.basis_elements, keys(c))
vals = values(c)
if vals isa AbstractVector
append!(mc.values, val .* vals)
else
append!(mc.values, val * collect(values(c)))
struct One end
Base.:*(::One, α) = α

function operate_with_constant!(::UnsafeAddMul, res, α, c)
for (k, v) in nonzero_pairs(c)
unsafe_push!(res, k, α * v)
end
return mc
return res
end

function MA.operate!(ms::UnsafeAddMul, res, v, w)
for (kv, a) in nonzero_pairs(v)
for (kw, b) in nonzero_pairs(w)
c = ms.structure(kv, kw)
MA.operate!(UnsafeAddMul(*), res, a * b, c)
function operate_with_constant!(op::UnsafeAddMul, res, α, b, c, args::Vararg{Any, N}) where {N}
for (kb, vb) in nonzero_pairs(b)
for (kc, vc) in nonzero_pairs(c)
operate_with_constant!(op, res, α * vb * vc, op.structure(kb, kc), args...)
end
end
return res
end

_aggregate_constants(constant, non_constant) = (constant, non_constant)

function _aggregate_constants(constant, non_constant, α, args::Vararg{Any,N}) where {N}
return _aggregate_constants(constant * α, non_constant, args...)
end

function _aggregate_constants(constant, non_constant, c::AbstractCoefficients, args::Vararg{Any,N}) where {N}
return _aggregate_constants(constant, (non_constant..., c), args...)
end

function MA.operate!(
op::UnsafeAddMul,
mc::AbstractCoefficients,
args::Vararg{Any,N},
) where {N}
constant, non_constant = _aggregate_constants(One(), tuple(), args...)
return operate_with_constant!(op, mc, constant, non_constant...)
end

struct DiracMStructure{Op} <: MultiplicativeStructure
op::Op
end
Expand Down
6 changes: 6 additions & 0 deletions src/sparse_coeffs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,12 @@ function MA.operate!(::typeof(canonical), res::SparseCoefficients)
return MA.operate!(canonical, res, comparable(key_type(res)))
end

function unsafe_push!(res::SparseCoefficients, key, value)
push!(res.basis_elements, key)
push!(res.values, value)
return res
end

# `::C` is needed to force Julia specialize on the function type
# Otherwise, we get one allocation when we call `issorted`
# See https://docs.julialang.org/en/v1/manual/performance-tips/#Be-aware-of-when-Julia-avoids-specializing
Expand Down

0 comments on commit a1fea31

Please sign in to comment.