From 39e964608dc93bff492c75b596cc4223b49f3af0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Tue, 2 Jul 2024 15:39:56 +0200 Subject: [PATCH] Implement promote_operation --- src/arithmetic.jl | 35 +++++++++++++++++++++++++---------- src/coefficients.jl | 3 +++ src/sparse_coeffs.jl | 7 +++++++ src/types.jl | 14 +++++++++++++- test/caching_allocations.jl | 16 ++++++++++++---- test/constructors.jl | 3 +++ 6 files changed, 63 insertions(+), 15 deletions(-) diff --git a/src/arithmetic.jl b/src/arithmetic.jl index 418c069..9645323 100644 --- a/src/arithmetic.jl +++ b/src/arithmetic.jl @@ -1,5 +1,14 @@ -_coeff_type(X::AlgebraElement) = eltype(X) -_coeff_type(a) = typeof(a) +_coeff_type(::Type{A}) where {A<:AlgebraElement} = eltype(A) +_coeff_type(a::Type) = a +_coeff_type(a) = _coeff_type(typeof(a)) + +function algebra_promote_operation(op, args::Vararg{Type,N}) where {N} + T = MA.promote_operation(op, _coeff_type.(args)...) + if args[2] <: AlgebraElement && MA.promote_operation(coeffs, args[2]) <: DenseArray # what a hack :) + return similar_type(args[2], T) + end + return similar_type(args[1], T) +end function _preallocate_output(op, args::Vararg{Any,N}) where {N} T = MA.promote_operation(op, _coeff_type.(args)...) @@ -18,6 +27,9 @@ Base.:(//)(X::AlgebraElement, a::Number) = 1 // a * X function Base.:-(X::AlgebraElement) return MA.operate_to!(_preallocate_output(*, X, -1), -, X) end +function MA.promote_operation(::typeof(*), ::Type{T}, ::Type{A}) where {T<:Number, A<:AlgebraElement} + return algebra_promote_operation(*, A, T) +end function Base.:*(a::Number, X::AlgebraElement) return MA.operate_to!(_preallocate_output(*, X, a), *, X, a) end @@ -25,18 +37,21 @@ function Base.:div(X::AlgebraElement, a::Number) return MA.operate_to!(_preallocate_output(div, X, a), div, X, a) end -function Base.:+(X::AlgebraElement, Y::AlgebraElement) - return MA.operate_to!(_preallocate_output(+, X, Y), +, X, Y) -end -function Base.:-(X::AlgebraElement, Y::AlgebraElement) - return MA.operate_to!(_preallocate_output(-, X, Y), -, X, Y) -end -function Base.:*(X::AlgebraElement, Y::AlgebraElement) - return MA.operate_to!(_preallocate_output(*, X, Y), *, X, Y) +for op in [:+, :-, :*] + @eval begin + function MA.promote_operation(::typeof($op), ::Type{X}, ::Type{Y}) where {X<:AlgebraElement,Y<:AlgebraElement} + return algebra_promote_operation($op, X, Y) + end + function Base.$op(X::AlgebraElement, Y::AlgebraElement) + return MA.operate_to!(_preallocate_output($op, X, Y), $op, X, Y) + end + end end + function Base.:*(args::Vararg{AlgebraElement,N}) where {N} return MA.operate_to!(_preallocate_output(*, args...), *, args...) end + Base.:^(a::AlgebraElement, p::Integer) = Base.power_by_squaring(a, p) # mutable API diff --git a/src/coefficients.jl b/src/coefficients.jl index 87c93fc..d6edde2 100644 --- a/src/coefficients.jl +++ b/src/coefficients.jl @@ -38,6 +38,9 @@ Base.iszero(ac::AbstractCoefficients) = isempty(keys(ac)) Base.similar(ac::AbstractCoefficients) = similar(ac, valtype(ac)) +similar_type(::Type{<:Vector}, ::Type{T}) where {T} = Vector{T} +similar_type(::Type{<:SparseArrays.SparseVector{C,I}}, ::Type{T}) where {C,I,T} = SparseArrays.SparseVector{T,I} + """ canonical(ac::AbstractCoefficients) Compute the canonical form of `ac` (e.g. grouping coefficients together, etc). diff --git a/src/sparse_coeffs.jl b/src/sparse_coeffs.jl index ff5cb27..1a4300d 100644 --- a/src/sparse_coeffs.jl +++ b/src/sparse_coeffs.jl @@ -47,6 +47,13 @@ _similar(x::Tuple, ::Type{T}) where {T} = Vector{T}(undef, length(x)) _similar(x) = similar(x) _similar(x, ::Type{T}) where {T} = similar(x, T) +_similar_type(::Type{<:Tuple}, ::Type{T}) where {T} = Vector{T} +_similar_type(::Type{V}, ::Type{T}) where {V,T} = similar_type(V, T) + +function similar_type(::Type{SparseCoefficients{K,V,Vk,Vv}}, ::Type{T}) where {K,V,Vk,Vv,T} + return SparseCoefficients{K,T,_similar_type(Vk, K),_similar_type(Vv, T)} +end + function Base.similar(s::SparseCoefficients, ::Type{T} = valtype(s)) where {T} return SparseCoefficients(_similar(s.basis_elements), _similar(s.values, T)) end diff --git a/src/types.jl b/src/types.jl index 2d015e3..e3bd5bf 100644 --- a/src/types.jl +++ b/src/types.jl @@ -25,6 +25,7 @@ struct StarAlgebra{O,T,B<:AbstractBasis{T}} <: AbstractStarAlgebra{O,T} end basis(A::StarAlgebra) = A.basis +MA.promote_operation(::typeof(basis), ::Type{StarAlgebra{O,T,B}}) where {O,T,B} = B object(A::StarAlgebra) = A.object struct AlgebraElement{A,T,V} <: MA.AbstractMutable @@ -33,7 +34,11 @@ struct AlgebraElement{A,T,V} <: MA.AbstractMutable end Base.parent(a::AlgebraElement) = a.parent -Base.eltype(a::AlgebraElement) = valtype(coeffs(a)) +Base.eltype(::Type{A}) where {A<:AlgebraElement} = valtype(MA.promote_operation(coeffs, A)) +Base.eltype(a::AlgebraElement) = eltype(typeof(a)) +function MA.promote_operation(::typeof(coeffs), ::Type{AlgebraElement{A,T,V}}) where {A,T,V} + return V +end coeffs(a::AlgebraElement) = a.coeffs function coeffs(x::AlgebraElement, b::AbstractBasis) return coeffs(coeffs(x), basis(x), b) @@ -41,6 +46,9 @@ end function adjoint_coeffs(a::AlgebraElement, target::AbstractBasis) return adjoint_coeffs(coeffs(a), target, basis(a)) end +function MA.promote_operation(::typeof(basis), ::Type{<:AlgebraElement{A}}) where {A} + return MA.promote_operation(basis, A) +end basis(a::AlgebraElement) = basis(parent(a)) function AlgebraElement(coeffs, A::AbstractStarAlgebra) @@ -106,6 +114,10 @@ end (A::AbstractStarAlgebra)(x::Number) = x * one(A) +function similar_type(::Type{AlgebraElement{A,T,V}}, ::Type{C}) where {A,T,V,C} + return AlgebraElement{A,C,similar_type(V, C)} +end + function Base.similar(X::AlgebraElement, T = eltype(X)) return AlgebraElement(similar(coeffs(X), T), parent(X)) end diff --git a/test/caching_allocations.jl b/test/caching_allocations.jl index 307a223..30a4b4a 100644 --- a/test/caching_allocations.jl +++ b/test/caching_allocations.jl @@ -8,6 +8,12 @@ end end +function _test_op(op, a, b) + result = @inferred op(a, b) + @test typeof(result) == MA.promote_operation(op, typeof(a), typeof(b)) + return result +end + @testset "FixedBasis caching && allocations" begin alph = [:a, :b, :c] A★ = FreeWords(alph) @@ -88,8 +94,10 @@ end (1,), ) Z = AlgebraElement(z, fRG) - @test Z + Z == 2 * Z - @test Z + Z == Y + Y - @test Y + Z == Y + Y - @test Z + Y == Y + Y + @test _test_op(+, Z, Z) == _test_op(*, 2, Z) + @test _test_op(+, Z, Z) == _test_op(+, Y, Y) + @test _test_op(+, Y, Z) == _test_op(+, Y, Y) + @test _test_op(+, Z, Y) == _test_op(+, Y, Y) + @test _test_op(-, Z, Z) == _test_op(*, 0, Z) + @test _test_op(-, Z, Z) == _test_op(-, Y, Z) end diff --git a/test/constructors.jl b/test/constructors.jl index 89f6a96..0fdf940 100644 --- a/test/constructors.jl +++ b/test/constructors.jl @@ -13,6 +13,7 @@ end A★ = FreeWords(alph) B = SA.DiracBasis(A★) RG = StarAlgebra(A★, B) + @test typeof(@inferred basis(RG)) == MA.promote_operation(basis, typeof(RG)) @test typeof(zero(RG)) == typeof(RG(0)) @test typeof(one(RG)) == typeof(RG(1)) @@ -33,6 +34,8 @@ end @test AlgebraElement(x, RG) isa AlgebraElement X = AlgebraElement(x, RG) + @test typeof(@inferred coeffs(X)) == MA.promote_operation(coeffs, typeof(X)) + @test typeof(@inferred basis(X)) == MA.promote_operation(basis, typeof(X)) @test AlgebraElement{Float64}(X) isa AlgebraElement Y = AlgebraElement{Float64}(X)