Skip to content

Commit

Permalink
replace args::Vararg multiplication by 4-argument one
Browse files Browse the repository at this point in the history
  • Loading branch information
kalmarek committed Dec 9, 2024
1 parent feb307c commit 0ea0688
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 55 deletions.
37 changes: 18 additions & 19 deletions src/arithmetic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,6 @@ for op in [:+, :-, :*]
end
end

function Base.:*(args::Vararg{AlgebraElement,N}) where {N}
return MA.operate_to!(_preallocate_output(*, args...), *, args...)
end

Base.:^(a::AlgebraElement, p::Integer) = Base.power_by_squaring(a, p)

# mutable API
Expand Down Expand Up @@ -130,34 +126,37 @@ function MA.operate_to!(
X::AlgebraElement,
Y::AlgebraElement,
)
@assert parent(res) === parent(X) === parent(Y)
@assert parent(res) == parent(X)
@assert parent(X) == parent(Y)
MA.operate_to!(coeffs(res), -, coeffs(X), coeffs(Y))
return res
end

function MA.operate_to!(
res::AlgebraElement,
::typeof(*),
args::Vararg{AlgebraElement,N},
) where {N}
for arg in args
@assert parent(res) == parent(arg)
end
A::AlgebraElement,
B::AlgebraElement,
α = true,
)
@assert parent(res) == parent(A)
@assert parent(res) == parent(B)
mstr = mstructure(basis(res))
MA.operate_to!(coeffs(res), mstr, coeffs.(args)...)
MA.operate_to!(coeffs(res), mstr, coeffs(A), coeffs(B), α)
return res
end

function MA.operate!(
::UnsafeAddMul{typeof(*)},
function MA.operate_to!(
res::AlgebraElement,
args::Vararg{AlgebraElement,N},
) where {N}
for arg in args
@assert parent(res) == parent(arg)
end
mul::UnsafeAddMul{typeof(*)},
A::AlgebraElement,
B::AlgebraElement,
α = true,
)
@assert parent(res) == parent(A)
@assert parent(res) == parent(B)
mstr = mstructure(basis(res))
MA.operate!(UnsafeAddMul(mstr), coeffs(res), coeffs.(args)...)
MA.operate_to!(coeffs(res), mul, coeffs(A), coeffs(B), α)
return res
end

Expand Down
62 changes: 31 additions & 31 deletions src/mstructures.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,53 +31,53 @@ When the product is not representable faithfully,
"""
abstract type MultiplicativeStructure end

"""
struct UnsafeAddMul{M<:Union{typeof(*),MultiplicativeStructure}}
structure::M
end
The value of `(op::UnsafeAddMul)(a, b, c)` is `a + structure(b, c)`
where `a` is not expected to be canonicalized before the operation `+`
and should not be expected to be canonicalized after either.
"""
struct UnsafeAddMul{M<:Union{typeof(*),MultiplicativeStructure}}
structure::M
end

function MA.operate_to!(res, ms::MultiplicativeStructure, args::Vararg{Any,N}) where {N}
if any(Base.Fix1(===, res), args)
throw(ArgumentError("No alias allowed"))
"""
operate_to!(res, ms::MultiplicativeStructure, A, B[, α = true])
Compute `α·A·B` storing the result in `res`. Return `res`.
`res` is assumed to behave like `AbstractCoefficients` and not aliased with any
other arguments.
`A` and `B` are assumed to behave like `AbstractCoefficients`, while `α` will
be treated as a scalar.
Canonicalization of the result happens only once at the end of the operation.
"""
function MA.operate_to!(res, ms::MultiplicativeStructure, A, B, α = true)
if res === A || res === B
throw(
ArgumentError(
"Aliasing arguments in multiplication is not supported",
),
)
end
MA.operate!(zero, res)
MA.operate!(UnsafeAddMul(ms), res, args...)
res = MA.operate_to!(res, UnsafeAddMul(ms), A, B, α)
MA.operate!(canonical, res)
return res
end

function MA.operate!(::UnsafeAddMul, res, c)
for (k, v) in nonzero_pairs(c)
struct UnsafeAdd end

function MA.operate_to!(res, ::UnsafeAdd, b)
for (k, v) in nonzero_pairs(b)
unsafe_push!(res, k, v)
end
return res
end

function MA.operate!(
op::UnsafeAddMul,
res,
b,
c,
args::Vararg{Any,N};
cfs = nothing,
) where {N}
for (kb, vb) in nonzero_pairs(b)
for (kc, vc) in nonzero_pairs(c)
for (k, v) in nonzero_pairs(op.structure(kb, kc))
_cfs = isnothing(cfs) ? vb * vc * v : vb * vc * v * cfs
MA.operate!(
op,
function MA.operate_to!(res, op::UnsafeAddMul, A, B, α = true)
for (kA, vA) in nonzero_pairs(A)
for (kB, vB) in nonzero_pairs(B)
for (k, v) in nonzero_pairs(op.structure(kA, kB))
cfs = MA.@rewrite α * vA * vB * v
MA.operate_to!(
res,
SparseCoefficients((_key(op.structure, k),), (_cfs,)),
args...,
UnsafeAdd(),
SparseCoefficients((_key(op.structure, k),), (cfs,)),
)
end
end
Expand Down
7 changes: 2 additions & 5 deletions test/monoid_algebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,8 @@
@test @allocated(MA.operate_to!(d, *, 2, a)) == 0
@test d == 2a

MA.operate!(zero, d)
MA.operate!(SA.UnsafeAddMul(*), d, a, b, b)
MA.operate!(SA.canonical, SA.coeffs(d))
@test a * b^2 == *(a, b, b)
@test d == *(a, b, b)
MA.operate_to!(d, *, a, b, 2)
@test d == 2 * a * b
end
end
end

0 comments on commit 0ea0688

Please sign in to comment.