diff --git a/src/algebra_elts.jl b/src/algebra_elts.jl index 8644903..e2e160b 100644 --- a/src/algebra_elts.jl +++ b/src/algebra_elts.jl @@ -28,9 +28,14 @@ end function LinearAlgebra.norm(a::AlgebraElement, p::Real) return LinearAlgebra.norm(coeffs(a), p) end -function LinearAlgebra.dot(a::AlgebraElement, v::AbstractVector) - return LinearAlgebra.dot(coeffs(a), v) + +function LinearAlgebra.dot(a::AlgebraElement, b::AlgebraElement) + return LinearAlgebra.dot(coeffs(a), coeffs(b)) +end + +function LinearAlgebra.dot(w::AbstractVector, b::AlgebraElement) + return LinearAlgebra.dot(w, coeffs(b)) end -function LinearAlgebra.dot(v::AbstractVector, a::AlgebraElement) - return LinearAlgebra.dot(a, v) +function LinearAlgebra.dot(a::AlgebraElement, w::AbstractVector) + return LinearAlgebra.dot(coeffs(a), w) end diff --git a/src/coefficients.jl b/src/coefficients.jl index df104f9..735be39 100644 --- a/src/coefficients.jl +++ b/src/coefficients.jl @@ -92,14 +92,34 @@ end end function LinearAlgebra.norm(sc::AbstractCoefficients, p::Real) - isempty(keys(sc)) && return (0^p)^1 / p - return sum(v^p for v in values(sc))^1 / p + isempty(values(sc)) && return (0^p)^(1 / p) + return sum(abs(v)^p for v in values(sc))^(1 / p) +end + +function LinearAlgebra.dot(ac::AbstractCoefficients, bc::AbstractCoefficients) + if isempty(values(ac)) || isempty(values(bc)) + return zero(Base._return_type(*, Tuple{valtype(ac),valtype(bc)})) + else + return sum(c * star(bc[i]) for (i, c) in nonzero_pairs(ac)) + end +end + +function LinearAlgebra.dot(w::AbstractVector, ac::AbstractCoefficients) + @assert key_type(ac) <: Integer + if isempty(values(ac)) + return zero(Base._return_type(*, eltype(w), valtype(ac))) + else + return sum(w[i] * star(v) for (i, v) in nonzero_pairs(ac)) + end end function LinearAlgebra.dot(ac::AbstractCoefficients, w::AbstractVector) @assert key_type(ac) <: Integer - isempty(keys(ac)) && return zero(eltype(v)) - return sum(v * w[i] for (i, v) in nonzero_pairs(ac)) + if isempty(values(ac)) + return zero(Base._return_type(*, eltype(w), valtype(ac))) + else + return sum(v * star(w[i]) for (i, v) in nonzero_pairs(ac)) + end end # general mutable API diff --git a/test/abstract_coeffs.jl b/test/abstract_coeffs.jl index a1b9df9..c5b6536 100644 --- a/test/abstract_coeffs.jl +++ b/test/abstract_coeffs.jl @@ -100,4 +100,8 @@ end @test fP2m * fP3 == fP3 * fP2m == fPAlt @test iszero(fP2m * fP2) + + @test norm(fP2m) == norm(P2m) == norm(fP2m) + v = coeffs(P2m, basis(fRG)) # an honest vector + @test dot(fP2m, fP2m) == dot(coeffs(fP2m), v) == dot(v, coeffs(fP2m)) end diff --git a/test/monoid_algebra.jl b/test/monoid_algebra.jl index 094913e..92b22cc 100644 --- a/test/monoid_algebra.jl +++ b/test/monoid_algebra.jl @@ -46,9 +46,11 @@ fX = AlgebraElement(coeffs(X, basis(fRG)), fRG) fY = AlgebraElement(coeffs(Y, basis(fRG)), fRG) - @test LinearAlgebra.dot(fX, coeffs(fX)) ≈ - norm(fX)^2 ≈ - LinearAlgebra.dot(coeffs(fX), fX) + @test dot(X, Y) == dot(fX, fY) == dot(coeffs(X), coeffs(Y)) + @test dot(fX, coeffs(fY)) == dot(coeffs(fX), fY) + + @test dot(X, X) ≈ norm(X)^2 ≈ dot(coeffs(X), coeffs(X)) + @test dot(fX, fX) ≈ norm(fX)^2 ≈ dot(coeffs(fX), coeffs(fX)) @test coeffs(fX) == coeffs(coeffs(fX, basis(RG)), basis(RG), basis(fRG))