From 653d90d4b6e19e80286a95013bc23cb90785b39c Mon Sep 17 00:00:00 2001 From: Marek Kaluba Date: Wed, 15 May 2024 13:19:04 +0200 Subject: [PATCH] fix norm/dot and provide coverage --- src/algebra_elts.jl | 13 +++++++++---- src/coefficients.jl | 28 ++++++++++++++++++++++++---- test/monoid_algebra.jl | 8 +++++--- 3 files changed, 38 insertions(+), 11 deletions(-) diff --git a/src/algebra_elts.jl b/src/algebra_elts.jl index 8644903..e2e160b 100644 --- a/src/algebra_elts.jl +++ b/src/algebra_elts.jl @@ -28,9 +28,14 @@ end function LinearAlgebra.norm(a::AlgebraElement, p::Real) return LinearAlgebra.norm(coeffs(a), p) end -function LinearAlgebra.dot(a::AlgebraElement, v::AbstractVector) - return LinearAlgebra.dot(coeffs(a), v) + +function LinearAlgebra.dot(a::AlgebraElement, b::AlgebraElement) + return LinearAlgebra.dot(coeffs(a), coeffs(b)) +end + +function LinearAlgebra.dot(w::AbstractVector, b::AlgebraElement) + return LinearAlgebra.dot(w, coeffs(b)) end -function LinearAlgebra.dot(v::AbstractVector, a::AlgebraElement) - return LinearAlgebra.dot(a, v) +function LinearAlgebra.dot(a::AlgebraElement, w::AbstractVector) + return LinearAlgebra.dot(coeffs(a), w) end diff --git a/src/coefficients.jl b/src/coefficients.jl index df104f9..735be39 100644 --- a/src/coefficients.jl +++ b/src/coefficients.jl @@ -92,14 +92,34 @@ end end function LinearAlgebra.norm(sc::AbstractCoefficients, p::Real) - isempty(keys(sc)) && return (0^p)^1 / p - return sum(v^p for v in values(sc))^1 / p + isempty(values(sc)) && return (0^p)^(1 / p) + return sum(abs(v)^p for v in values(sc))^(1 / p) +end + +function LinearAlgebra.dot(ac::AbstractCoefficients, bc::AbstractCoefficients) + if isempty(values(ac)) || isempty(values(bc)) + return zero(Base._return_type(*, Tuple{valtype(ac),valtype(bc)})) + else + return sum(c * star(bc[i]) for (i, c) in nonzero_pairs(ac)) + end +end + +function LinearAlgebra.dot(w::AbstractVector, ac::AbstractCoefficients) + @assert key_type(ac) <: Integer + if isempty(values(ac)) + return zero(Base._return_type(*, eltype(w), valtype(ac))) + else + return sum(w[i] * star(v) for (i, v) in nonzero_pairs(ac)) + end end function LinearAlgebra.dot(ac::AbstractCoefficients, w::AbstractVector) @assert key_type(ac) <: Integer - isempty(keys(ac)) && return zero(eltype(v)) - return sum(v * w[i] for (i, v) in nonzero_pairs(ac)) + if isempty(values(ac)) + return zero(Base._return_type(*, eltype(w), valtype(ac))) + else + return sum(v * star(w[i]) for (i, v) in nonzero_pairs(ac)) + end end # general mutable API diff --git a/test/monoid_algebra.jl b/test/monoid_algebra.jl index 094913e..92b22cc 100644 --- a/test/monoid_algebra.jl +++ b/test/monoid_algebra.jl @@ -46,9 +46,11 @@ fX = AlgebraElement(coeffs(X, basis(fRG)), fRG) fY = AlgebraElement(coeffs(Y, basis(fRG)), fRG) - @test LinearAlgebra.dot(fX, coeffs(fX)) ≈ - norm(fX)^2 ≈ - LinearAlgebra.dot(coeffs(fX), fX) + @test dot(X, Y) == dot(fX, fY) == dot(coeffs(X), coeffs(Y)) + @test dot(fX, coeffs(fY)) == dot(coeffs(fX), fY) + + @test dot(X, X) ≈ norm(X)^2 ≈ dot(coeffs(X), coeffs(X)) + @test dot(fX, fX) ≈ norm(fX)^2 ≈ dot(coeffs(fX), coeffs(fX)) @test coeffs(fX) == coeffs(coeffs(fX, basis(RG)), basis(RG), basis(fRG))