Skip to content

Commit

Permalink
Implement promote_operation
Browse files Browse the repository at this point in the history
  • Loading branch information
blegat committed Jul 2, 2024
1 parent f9ff304 commit 39e9646
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 15 deletions.
35 changes: 25 additions & 10 deletions src/arithmetic.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
_coeff_type(X::AlgebraElement) = eltype(X)
_coeff_type(a) = typeof(a)
_coeff_type(::Type{A}) where {A<:AlgebraElement} = eltype(A)
_coeff_type(a::Type) = a
_coeff_type(a) = _coeff_type(typeof(a))

function algebra_promote_operation(op, args::Vararg{Type,N}) where {N}
T = MA.promote_operation(op, _coeff_type.(args)...)
if args[2] <: AlgebraElement && MA.promote_operation(coeffs, args[2]) <: DenseArray # what a hack :)
return similar_type(args[2], T)

Check warning on line 8 in src/arithmetic.jl

View check run for this annotation

Codecov / codecov/patch

src/arithmetic.jl#L8

Added line #L8 was not covered by tests
end
return similar_type(args[1], T)
end

function _preallocate_output(op, args::Vararg{Any,N}) where {N}
T = MA.promote_operation(op, _coeff_type.(args)...)
Expand All @@ -18,25 +27,31 @@ Base.:(//)(X::AlgebraElement, a::Number) = 1 // a * X
function Base.:-(X::AlgebraElement)
return MA.operate_to!(_preallocate_output(*, X, -1), -, X)
end
function MA.promote_operation(::typeof(*), ::Type{T}, ::Type{A}) where {T<:Number, A<:AlgebraElement}
return algebra_promote_operation(*, A, T)
end
function Base.:*(a::Number, X::AlgebraElement)
return MA.operate_to!(_preallocate_output(*, X, a), *, X, a)
end
function Base.:div(X::AlgebraElement, a::Number)
return MA.operate_to!(_preallocate_output(div, X, a), div, X, a)
end

function Base.:+(X::AlgebraElement, Y::AlgebraElement)
return MA.operate_to!(_preallocate_output(+, X, Y), +, X, Y)
end
function Base.:-(X::AlgebraElement, Y::AlgebraElement)
return MA.operate_to!(_preallocate_output(-, X, Y), -, X, Y)
end
function Base.:*(X::AlgebraElement, Y::AlgebraElement)
return MA.operate_to!(_preallocate_output(*, X, Y), *, X, Y)
for op in [:+, :-, :*]
@eval begin
function MA.promote_operation(::typeof($op), ::Type{X}, ::Type{Y}) where {X<:AlgebraElement,Y<:AlgebraElement}
return algebra_promote_operation($op, X, Y)
end
function Base.$op(X::AlgebraElement, Y::AlgebraElement)
return MA.operate_to!(_preallocate_output($op, X, Y), $op, X, Y)
end
end
end

function Base.:*(args::Vararg{AlgebraElement,N}) where {N}
return MA.operate_to!(_preallocate_output(*, args...), *, args...)
end

Base.:^(a::AlgebraElement, p::Integer) = Base.power_by_squaring(a, p)

# mutable API
Expand Down
3 changes: 3 additions & 0 deletions src/coefficients.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ Base.iszero(ac::AbstractCoefficients) = isempty(keys(ac))

Base.similar(ac::AbstractCoefficients) = similar(ac, valtype(ac))

similar_type(::Type{<:Vector}, ::Type{T}) where {T} = Vector{T}
similar_type(::Type{<:SparseArrays.SparseVector{C,I}}, ::Type{T}) where {C,I,T} = SparseArrays.SparseVector{T,I}

Check warning on line 42 in src/coefficients.jl

View check run for this annotation

Codecov / codecov/patch

src/coefficients.jl#L42

Added line #L42 was not covered by tests

"""
canonical(ac::AbstractCoefficients)
Compute the canonical form of `ac` (e.g. grouping coefficients together, etc).
Expand Down
7 changes: 7 additions & 0 deletions src/sparse_coeffs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,13 @@ _similar(x::Tuple, ::Type{T}) where {T} = Vector{T}(undef, length(x))
_similar(x) = similar(x)
_similar(x, ::Type{T}) where {T} = similar(x, T)

_similar_type(::Type{<:Tuple}, ::Type{T}) where {T} = Vector{T}
_similar_type(::Type{V}, ::Type{T}) where {V,T} = similar_type(V, T)

function similar_type(::Type{SparseCoefficients{K,V,Vk,Vv}}, ::Type{T}) where {K,V,Vk,Vv,T}
return SparseCoefficients{K,T,_similar_type(Vk, K),_similar_type(Vv, T)}
end

function Base.similar(s::SparseCoefficients, ::Type{T} = valtype(s)) where {T}
return SparseCoefficients(_similar(s.basis_elements), _similar(s.values, T))
end
Expand Down
14 changes: 13 additions & 1 deletion src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ struct StarAlgebra{O,T,B<:AbstractBasis{T}} <: AbstractStarAlgebra{O,T}
end

basis(A::StarAlgebra) = A.basis
MA.promote_operation(::typeof(basis), ::Type{StarAlgebra{O,T,B}}) where {O,T,B} = B
object(A::StarAlgebra) = A.object

struct AlgebraElement{A,T,V} <: MA.AbstractMutable
Expand All @@ -33,14 +34,21 @@ struct AlgebraElement{A,T,V} <: MA.AbstractMutable
end

Base.parent(a::AlgebraElement) = a.parent
Base.eltype(a::AlgebraElement) = valtype(coeffs(a))
Base.eltype(::Type{A}) where {A<:AlgebraElement} = valtype(MA.promote_operation(coeffs, A))
Base.eltype(a::AlgebraElement) = eltype(typeof(a))
function MA.promote_operation(::typeof(coeffs), ::Type{AlgebraElement{A,T,V}}) where {A,T,V}
return V
end
coeffs(a::AlgebraElement) = a.coeffs
function coeffs(x::AlgebraElement, b::AbstractBasis)
return coeffs(coeffs(x), basis(x), b)
end
function adjoint_coeffs(a::AlgebraElement, target::AbstractBasis)
return adjoint_coeffs(coeffs(a), target, basis(a))
end
function MA.promote_operation(::typeof(basis), ::Type{<:AlgebraElement{A}}) where {A}
return MA.promote_operation(basis, A)
end
basis(a::AlgebraElement) = basis(parent(a))

function AlgebraElement(coeffs, A::AbstractStarAlgebra)
Expand Down Expand Up @@ -106,6 +114,10 @@ end

(A::AbstractStarAlgebra)(x::Number) = x * one(A)

function similar_type(::Type{AlgebraElement{A,T,V}}, ::Type{C}) where {A,T,V,C}
return AlgebraElement{A,C,similar_type(V, C)}
end

function Base.similar(X::AlgebraElement, T = eltype(X))
return AlgebraElement(similar(coeffs(X), T), parent(X))
end
Expand Down
16 changes: 12 additions & 4 deletions test/caching_allocations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@
end
end

function _test_op(op, a, b)
result = @inferred op(a, b)
@test typeof(result) == MA.promote_operation(op, typeof(a), typeof(b))
return result
end

@testset "FixedBasis caching && allocations" begin
alph = [:a, :b, :c]
A★ = FreeWords(alph)
Expand Down Expand Up @@ -88,8 +94,10 @@ end
(1,),
)
Z = AlgebraElement(z, fRG)
@test Z + Z == 2 * Z
@test Z + Z == Y + Y
@test Y + Z == Y + Y
@test Z + Y == Y + Y
@test _test_op(+, Z, Z) == _test_op(*, 2, Z)
@test _test_op(+, Z, Z) == _test_op(+, Y, Y)
@test _test_op(+, Y, Z) == _test_op(+, Y, Y)
@test _test_op(+, Z, Y) == _test_op(+, Y, Y)
@test _test_op(-, Z, Z) == _test_op(*, 0, Z)
@test _test_op(-, Z, Z) == _test_op(-, Y, Z)
end
3 changes: 3 additions & 0 deletions test/constructors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ end
A★ = FreeWords(alph)
B = SA.DiracBasis(A★)
RG = StarAlgebra(A★, B)
@test typeof(@inferred basis(RG)) == MA.promote_operation(basis, typeof(RG))

@test typeof(zero(RG)) == typeof(RG(0))
@test typeof(one(RG)) == typeof(RG(1))
Expand All @@ -33,6 +34,8 @@ end
@test AlgebraElement(x, RG) isa AlgebraElement

X = AlgebraElement(x, RG)
@test typeof(@inferred coeffs(X)) == MA.promote_operation(coeffs, typeof(X))
@test typeof(@inferred basis(X)) == MA.promote_operation(basis, typeof(X))

@test AlgebraElement{Float64}(X) isa AlgebraElement
Y = AlgebraElement{Float64}(X)
Expand Down

0 comments on commit 39e9646

Please sign in to comment.