diff --git a/src/arithmetic.jl b/src/arithmetic.jl index 080cf77..85aaf6c 100644 --- a/src/arithmetic.jl +++ b/src/arithmetic.jl @@ -11,6 +11,10 @@ function _preallocate_output(X::AlgebraElement, Y::AlgebraElement, op) return similar(X, T) end +function MA.promote_operation(::typeof(similar), ::Type{<:Vector}, ::Type{T}) where {T} + return Vector{T} +end + # module structure: Base.:*(X::AlgebraElement, a::Number) = a * X @@ -27,12 +31,21 @@ function Base.:div(X::AlgebraElement, a::Number) return MA.operate_to!(_preallocate_output(X, a, div), div, X, a) end +function MA.promote_operation( + op::Union{typeof(+),typeof(-)}, + ::Type{AlgebraElement{A,T,VT}}, + ::Type{AlgebraElement{A,S,VS}}, +) where {A,T,VT,S,VS} + U = MA.promote_operation(op, T, S) + return AlgebraElement{A,U,MA.promote_operation(similar, VT, U)} +end function Base.:+(X::AlgebraElement, Y::AlgebraElement) return MA.operate_to!(_preallocate_output(X, Y, +), +, X, Y) end function Base.:-(X::AlgebraElement, Y::AlgebraElement) return MA.operate_to!(_preallocate_output(X, Y, -), -, X, Y) end + function Base.:*(X::AlgebraElement, Y::AlgebraElement) return MA.operate_to!(_preallocate_output(X, Y, *), *, X, Y) end diff --git a/src/show.jl b/src/show.jl index a8182cf..234cb42 100644 --- a/src/show.jl +++ b/src/show.jl @@ -8,25 +8,100 @@ __prints_with_minus(x::Real) = x < 0 __needs_parens(::Any) = false __needs_parens(a::AlgebraElement) = true -function _coeff_elt_print(io, c, elt) - print(io, c, '·') +# `Int`, `Float64` don't support MIME"text/latex". +# We could add a check with `showable` if a `Real` subtype supports it and +# the feature is requested. +print_coefficient(io::IO, mime, coeff::Real) = print(io, coeff) +# Scientific notation does not display well in LaTeX so we rewrite it +function print_coefficient(io::IO, mime::MIME"text/latex", coeff::AbstractFloat) + s = string(coeff) + if occursin('e', s) + s = replace(s, 'e' => " \\cdot 10^{") * '}' + end + return print(io, s) +end + +trim_LaTeX(::MIME, s::AbstractString) = s + +function trim_LaTeX(::MIME"text/latex", s::AbstractString) + i = firstindex(s) + j = lastindex(s) + while true + if i < j && isspace(s[i]) + i = nextind(s, i) + elseif i < j && isspace(s[j]) + j = prevind(s, j) + elseif i < j && s[i] == '$' && s[j] == '$' + i = nextind(s, i) + j = prevind(s, j) + elseif i < j && ( + (s[i:nextind(s, i)] == "\\(" && s[prevind(s, j):j] == "\\)") || + (s[i:nextind(s, i)] == "\\[" && s[prevind(s, j):j] == "\\]") + ) + i = nextind(s, i, 2) + j = prevind(s, j, 2) + else + return s[i:j] + end + end +end + +# JuMP expressions supports LaTeX output so `showable` will return `true` +# for them. It is important for anonymous variables to display properly as well: +# https://github.com/jump-dev/SumOfSquares.jl/issues/256 +# Since they add `$$` around it, we need to trim it with `trim_LaTeX` +function print_coefficient(io::IO, mime, coeff) + print(io, "(") + if showable(mime, coeff) + print(io, trim_LaTeX(mime, sprint(show, mime, coeff))) + else + show(io, coeff) + end + return print(io, ")") +end + +_print_dot(io, ::MIME"text/latex") = print(io, " \\cdot ") +_print_dot(io, ::MIME) = print(io, '⋅') + +function _coeff_elt_print(io, mime, c, elt) + print_coefficient(io, mime, c) + _print_dot(io, mime) __needs_parens(elt) && print(io, '(') - print(io, elt) + print(io, trim_LaTeX(mime, sprint(show, mime, elt))) __needs_parens(elt) && print(io, ')') return end -function Base.show(io::IO, a::AlgebraElement) +Base.print(io::IO, a::AlgebraElement) = show(io, MIME"text/print"(), a) +Base.show(io::IO, a::AlgebraElement) = show(io, MIME"text/plain"(), a) + +function Base.show(io::IO, mime::MIME"text/latex", a::AlgebraElement) + print(io, "\$\$ ") + _show(io, mime, a) + return print(io, " \$\$") +end + +# If the MIME is not specified, IJulia thinks that it supports images, ... +# and then use the result of show and tries to interpret it as an svg, ... +# We need the two methods to avoid ambiguity +function Base.show(io::IO, mime::MIME"text/plain", a::AlgebraElement) + return _show(io, mime, a) +end +function Base.show(io::IO, mime::MIME"text/print", a::AlgebraElement) + return _show(io, mime, a) +end + +function _show(io::IO, mime, a::AlgebraElement) A = parent(a) if iszero(a) T = valtype(coeffs(a)) - _coeff_elt_print(io, zero(T), first(basis(A))) + _coeff_elt_print(io, mime, zero(T), first(basis(A))) else _first = true for (idx, value) in nonzero_pairs(coeffs(a)) c, elt = value, basis(A)[idx] if _first - _coeff_elt_print(io, c, elt) + _coeff_elt_print(io, mime, c, elt) _first = false else if __prints_with_minus(c) @@ -34,7 +109,7 @@ function Base.show(io::IO, a::AlgebraElement) else print(io, ' ', '+', ' ') end - _coeff_elt_print(io, c, elt) + _coeff_elt_print(io, mime, c, elt) end end end diff --git a/src/sparse_coeffs.jl b/src/sparse_coeffs.jl index bd828c9..c9ebeb0 100644 --- a/src/sparse_coeffs.jl +++ b/src/sparse_coeffs.jl @@ -42,6 +42,14 @@ function Base.zero(sc::SparseCoefficients) return SparseCoefficients(empty(keys(sc)), empty(values(sc))) end +function MA.promote_operation( + ::typeof(similar), + ::Type{SparseCoefficients{K,V,Vk,Vv}}, + ::Type{T}, +) where {K,V,Vk,Vv,T} + return SparseCoefficients{K,T,Vk,MA.promote_operation(similar, Vv, T)} +end + function Base.similar(s::SparseCoefficients, ::Type{T} = valtype(s)) where {T} return SparseCoefficients(similar(s.basis_elements), similar(s.values, T)) end