Skip to content

Commit

Permalink
up
Browse files Browse the repository at this point in the history
  • Loading branch information
blegat committed Jun 12, 2024
1 parent 6dc55ec commit 2ec2cc4
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 7 deletions.
13 changes: 13 additions & 0 deletions src/arithmetic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ function _preallocate_output(X::AlgebraElement, Y::AlgebraElement, op)
return similar(X, T)
end

function MA.promote_operation(::typeof(similar), ::Type{<:Vector}, ::Type{T}) where {T}
return Vector{T}
end

# module structure:

Base.:*(X::AlgebraElement, a::Number) = a * X
Expand All @@ -27,12 +31,21 @@ function Base.:div(X::AlgebraElement, a::Number)
return MA.operate_to!(_preallocate_output(X, a, div), div, X, a)
end

function MA.promote_operation(
op::Union{typeof(+),typeof(-)},
::Type{AlgebraElement{A,T,VT}},
::Type{AlgebraElement{A,S,VS}},
) where {A,T,VT,S,VS}
U = MA.promote_operation(op, T, S)
return AlgebraElement{A,U,MA.promote_operation(similar, VT, U)}
end
function Base.:+(X::AlgebraElement, Y::AlgebraElement)
return MA.operate_to!(_preallocate_output(X, Y, +), +, X, Y)
end
function Base.:-(X::AlgebraElement, Y::AlgebraElement)
return MA.operate_to!(_preallocate_output(X, Y, -), -, X, Y)
end

function Base.:*(X::AlgebraElement, Y::AlgebraElement)
return MA.operate_to!(_preallocate_output(X, Y, *), *, X, Y)
end
Expand Down
89 changes: 82 additions & 7 deletions src/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,33 +8,108 @@ __prints_with_minus(x::Real) = x < 0
__needs_parens(::Any) = false
__needs_parens(a::AlgebraElement) = true

function _coeff_elt_print(io, c, elt)
print(io, c, '·')
# `Int`, `Float64` don't support MIME"text/latex".
# We could add a check with `showable` if a `Real` subtype supports it and
# the feature is requested.
print_coefficient(io::IO, mime, coeff::Real) = print(io, coeff)
# Scientific notation does not display well in LaTeX so we rewrite it
function print_coefficient(io::IO, mime::MIME"text/latex", coeff::AbstractFloat)
s = string(coeff)
if occursin('e', s)
s = replace(s, 'e' => " \\cdot 10^{") * '}'
end
return print(io, s)
end

trim_LaTeX(::MIME, s::AbstractString) = s

function trim_LaTeX(::MIME"text/latex", s::AbstractString)
i = firstindex(s)
j = lastindex(s)
while true
if i < j && isspace(s[i])
i = nextind(s, i)
elseif i < j && isspace(s[j])
j = prevind(s, j)
elseif i < j && s[i] == '$' && s[j] == '$'
i = nextind(s, i)
j = prevind(s, j)
elseif i < j && (
(s[i:nextind(s, i)] == "\\(" && s[prevind(s, j):j] == "\\)") ||
(s[i:nextind(s, i)] == "\\[" && s[prevind(s, j):j] == "\\]")
)
i = nextind(s, i, 2)
j = prevind(s, j, 2)
else
return s[i:j]
end
end
end

# JuMP expressions supports LaTeX output so `showable` will return `true`
# for them. It is important for anonymous variables to display properly as well:
# https://github.com/jump-dev/SumOfSquares.jl/issues/256
# Since they add `$$` around it, we need to trim it with `trim_LaTeX`
function print_coefficient(io::IO, mime, coeff)
print(io, "(")
if showable(mime, coeff)
print(io, trim_LaTeX(mime, sprint(show, mime, coeff)))
else
show(io, coeff)
end
return print(io, ")")
end

_print_dot(io, ::MIME"text/latex") = print(io, " \\cdot ")
_print_dot(io, ::MIME) = print(io, '')

function _coeff_elt_print(io, mime, c, elt)
print_coefficient(io, mime, c)
_print_dot(io, mime)
__needs_parens(elt) && print(io, '(')
print(io, elt)
print(io, trim_LaTeX(mime, sprint(show, mime, elt)))
__needs_parens(elt) && print(io, ')')
return
end

function Base.show(io::IO, a::AlgebraElement)
Base.print(io::IO, a::AlgebraElement) = show(io, MIME"text/print"(), a)
Base.show(io::IO, a::AlgebraElement) = show(io, MIME"text/plain"(), a)

function Base.show(io::IO, mime::MIME"text/latex", a::AlgebraElement)
print(io, "\$\$ ")
_show(io, mime, a)
return print(io, " \$\$")
end

# If the MIME is not specified, IJulia thinks that it supports images, ...
# and then use the result of show and tries to interpret it as an svg, ...
# We need the two methods to avoid ambiguity
function Base.show(io::IO, mime::MIME"text/plain", a::AlgebraElement)
return _show(io, mime, a)
end
function Base.show(io::IO, mime::MIME"text/print", a::AlgebraElement)
return _show(io, mime, a)
end

function _show(io::IO, mime, a::AlgebraElement)
A = parent(a)
if iszero(a)
T = valtype(coeffs(a))
_coeff_elt_print(io, zero(T), first(basis(A)))
_coeff_elt_print(io, mime, zero(T), first(basis(A)))
else
_first = true
for (idx, value) in nonzero_pairs(coeffs(a))
c, elt = value, basis(A)[idx]
if _first
_coeff_elt_print(io, c, elt)
_coeff_elt_print(io, mime, c, elt)
_first = false
else
if __prints_with_minus(c)
print(io, ' ')
else
print(io, ' ', '+', ' ')
end
_coeff_elt_print(io, c, elt)
_coeff_elt_print(io, mime, c, elt)
end
end
end
Expand Down
8 changes: 8 additions & 0 deletions src/sparse_coeffs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,14 @@ function Base.zero(sc::SparseCoefficients)
return SparseCoefficients(empty(keys(sc)), empty(values(sc)))
end

function MA.promote_operation(
::typeof(similar),
::Type{SparseCoefficients{K,V,Vk,Vv}},
::Type{T},
) where {K,V,Vk,Vv,T}
return SparseCoefficients{K,T,Vk,MA.promote_operation(similar, Vv, T)}
end

function Base.similar(s::SparseCoefficients, ::Type{T} = valtype(s)) where {T}
return SparseCoefficients(similar(s.basis_elements), similar(s.values, T))
end
Expand Down

0 comments on commit 2ec2cc4

Please sign in to comment.