From 5e5f235e99a490005611b7a22474d6fb9c7ab7ab Mon Sep 17 00:00:00 2001 From: mattsignorelli Date: Thu, 18 Apr 2024 14:30:56 -0400 Subject: [PATCH] clean up of basic operators, added in-place add,sub,mul,div --- docs/src/man/o_all.md | 4 +- src/GTPSA.jl | 6 +- src/operators.jl | 301 +++++++++++++----------------------------- src/utils.jl | 6 +- 4 files changed, 99 insertions(+), 218 deletions(-) diff --git a/docs/src/man/o_all.md b/docs/src/man/o_all.md index 77db353f..bd7fc9d8 100644 --- a/docs/src/man/o_all.md +++ b/docs/src/man/o_all.md @@ -12,12 +12,12 @@ conj, angle, complex, promote_rule, getindex, setindex!, ==, <, `zeros` and `ones` are overloaded from Base so that allocated `TPS`/`ComplexTPS`s are placed in each element. If we didn't explicity overload these functions, every element would correspond to the exact same heap-allocated TPS, which is problematic when setting individual monomial coefficients of the same TPS. `GTPSA.jl` overloads (and exports) the following functions from the corresponding packages: - **`LinearAlgebra`**: `norm` + **`LinearAlgebra`**: `norm`, `mul!` **`SpecialFunctions`**: `erf`, `erfc` `GTPSA.jl` also provides the following functions NOT included in Base or any of the above packages: ``` -unit, sinhc, asinc, asinhc, polar, rect +add!, sub!, div!, unit, sinhc, asinc, asinhc, polar, rect ``` If there is a mathematical function in Base which you'd like and is not included in the above list, feel free to submit an [issue](https://github.com/bmad-sim/GTPSA.jl/issues). \ No newline at end of file diff --git a/src/GTPSA.jl b/src/GTPSA.jl index 079d95b1..0c5fcb5e 100644 --- a/src/GTPSA.jl +++ b/src/GTPSA.jl @@ -62,7 +62,7 @@ import Base: +, show, copy! -import LinearAlgebra: norm +import LinearAlgebra: norm, mul! import SpecialFunctions: erf, erfc using GTPSA_jll, Printf, PrettyTables @@ -83,6 +83,10 @@ export rect, clear!, complex!, + add!, + sub!, + mul!, + div!, # Monomial as TPS creators: vars, diff --git a/src/operators.jl b/src/operators.jl index 35f2db95..cd64a228 100644 --- a/src/operators.jl +++ b/src/operators.jl @@ -3,6 +3,10 @@ function copy!(t::TPS, t1::TPS) mad_tpsa_copy!(t1.tpsa, t.tpsa) end +function copy!(ct::ComplexTPS, t1::TPS) + mad_ctpsa_cplx!(t1.tpsa, Base.unsafe_convert(Ptr{RTPSA}, C_NULL), ct.tpsa) +end + function copy!(ct::ComplexTPS, ct1::ComplexTPS) mad_ctpsa_copy!(ct1.tpsa, ct.tpsa) end @@ -249,255 +253,128 @@ function isequal(t1::TPS, ct1::ComplexTPS)::Bool return isequal(ct1, t1) end - # --- add --- -# TPS: -function +(t1::TPS, t2::TPS)::TPS - t = zero(t1) - mad_tpsa_add!(t1.tpsa, t2.tpsa, t.tpsa) - return t -end +# TPS, TPS: +add!(a::Ptr{RTPSA}, b::Ptr{RTPSA}, c::Ptr{RTPSA}) = mad_tpsa_add!(a, b, c) +add!(a::Ptr{CTPSA}, b::Ptr{CTPSA}, c::Ptr{CTPSA}) = mad_ctpsa_add!(a, b, c) +add!(a::Ptr{RTPSA}, b::Ptr{CTPSA}, c::Ptr{CTPSA}) = mad_ctpsa_addt!(b, a, c) +add!(a::Ptr{CTPSA}, b::Ptr{RTPSA}, c::Ptr{CTPSA}) = mad_ctpsa_addt!(a, b, c) -function +(t1::TPS, a::Real)::TPS - t = TPS(t1) - mad_tpsa_set0!(t.tpsa, 1., convert(Float64,a)) - return t -end - -function +(a::Real, t1::TPS)::TPS - return t1 + a -end - -# ComplexTPS: -function +(ct1::ComplexTPS, ct2::ComplexTPS)::ComplexTPS - ct = zero(ct1) - mad_ctpsa_add!(ct1.tpsa, ct2.tpsa, ct.tpsa) - return ct -end +add!(t::Union{TPS,ComplexTPS}, t1::Union{TPS,ComplexTPS}, t2::Union{TPS,ComplexTPS}) = add!(t1.tpsa, t2.tpsa, t.tpsa) -function +(ct1::ComplexTPS, a::Number)::ComplexTPS - ct = ComplexTPS(ct1) - mad_ctpsa_set0!(ct.tpsa, convert(ComplexF64, 1), convert(ComplexF64, a)) - return ct -end +# TPS, scalar: +set0!(t::Ptr{RTPSA}, a::Float64, b::Float64) = mad_tpsa_set0!(t, a, b) +set0!(t::Ptr{CTPSA}, a::ComplexF64, b::ComplexF64) = mad_ctpsa_set0!(t, a, b) -function +(a::Number, ct1::ComplexTPS)::ComplexTPS - return ct1 + a +function add!(t::Union{TPS,ComplexTPS}, t1::Union{TPS,ComplexTPS}, a::Number) + copy!(t, t1) + set0!(t.tpsa, convert(numtype(t), 1), convert(numtype(t), a)) end -# TPS to ComplexTPS promotion, w/o creating temp ComplexTPS: -function +(ct1::ComplexTPS, t1::TPS)::ComplexTPS - ct = zero(ct1) - mad_ctpsa_addt!(ct1.tpsa, t1.tpsa, ct.tpsa) - return ct -end +add!(t::Union{TPS,ComplexTPS}, a::Number, t1::Union{TPS,ComplexTPS}) = add!(t, t1, a) -function +(t1::TPS, ct1::ComplexTPS)::ComplexTPS - return ct1 + t1 +for t = ((TPS,TPS),(TPS,Real),(Real,TPS),(TPS,Complex),(Complex,TPS),(ComplexTPS,TPS),(TPS,ComplexTPS),(ComplexTPS,ComplexTPS),(ComplexTPS, Number), (Number, ComplexTPS)) +@eval begin +function +(t1::$t[1], t2::$t[2]) + use = $(t[1] == TPS || t[1] == ComplexTPS ? :t1 : :t2) + t = (promote_type(typeof(t1),typeof(t2)))(use=use) + add!(t, t1, t2) + return t end - -function +(t1::TPS, a::Complex)::ComplexTPS - ct = ComplexTPS(t1) - mad_ctpsa_set0!(ct.tpsa, convert(ComplexF64, 1), convert(ComplexF64, a)) - return ct end - -function +(a::Complex, t1::TPS)::ComplexTPS - return t1 + a end # --- sub --- -# TPS: -function -(t1::TPS, t2::TPS)::TPS - t = zero(t1) - mad_tpsa_sub!(t1.tpsa, t2.tpsa, t.tpsa) - return t -end - +# TPS, TPS: +sub!(a::Ptr{RTPSA}, b::Ptr{RTPSA}, c::Ptr{RTPSA}) = mad_tpsa_sub!(a, b, c) +sub!(a::Ptr{CTPSA}, b::Ptr{CTPSA}, c::Ptr{CTPSA}) = mad_ctpsa_sub!(a, b, c) +sub!(a::Ptr{CTPSA}, b::Ptr{RTPSA}, c::Ptr{CTPSA}) = mad_ctpsa_subt!(a, b, c) +sub!(a::Ptr{RTPSA}, b::Ptr{CTPSA}, c::Ptr{CTPSA}) = mad_ctpsa_tsub!(a, b, c) -function -(t1::TPS, a::Real)::TPS - t = TPS(t1) - mad_tpsa_set0!(t.tpsa, 1., convert(Float64, -a)) - return t -end +sub!(t::Union{TPS,ComplexTPS}, t1::Union{TPS,ComplexTPS}, t2::Union{TPS,ComplexTPS}) = sub!(t1.tpsa, t2.tpsa, t.tpsa) -function -(a::Real, t1::TPS)::TPS - t = zero(t1) - mad_tpsa_scl!(t1.tpsa, -1., t.tpsa) - mad_tpsa_set0!(t.tpsa, 1., convert(Float64, a)) - return t +# TPS, scalar: +scl!(a::Ptr{RTPSA}, v::Float64, c::Ptr{RTPSA}) = mad_tpsa_scl!(a, v, c) +scl!(a::Ptr{CTPSA}, v::ComplexF64, c::Ptr{CTPSA}) = mad_ctpsa_scl!(a, v, c) +function scl!(a::Ptr{RTPSA}, v::ComplexF64, c::Ptr{CTPSA}) + mad_ctpsa_cplx!(a, Base.unsafe_convert(Ptr{RTPSA},C_NULL), c) + scl!(c, v, c) end -# ComplexTPS: -function -(ct1::ComplexTPS, ct2::ComplexTPS)::ComplexTPS - ct = zero(ct1) - mad_ctpsa_sub!(ct1.tpsa, ct2.tpsa, ct.tpsa) - return ct -end - -function -(ct1::ComplexTPS, a::Number)::ComplexTPS - ct = ComplexTPS(ct1) - mad_ctpsa_set0!(ct.tpsa, convert(ComplexF64, 1), convert(ComplexF64, -a)) - return ct -end - -function -(a::Number, ct1::ComplexTPS)::ComplexTPS - ct = zero(ct1) - mad_ctpsa_scl!(ct1.tpsa, convert(ComplexF64, -1), ct.tpsa) - mad_ctpsa_set0!(ct.tpsa, convert(ComplexF64, 1), convert(ComplexF64,a)) - return ct -end +sub!(t::Union{TPS,ComplexTPS}, t1::Union{TPS,ComplexTPS}, a::Number) = add!(t, t1, -a) -# TPS to ComplexTPS promotion, w/o creating temp ComplexTPS: -function -(ct1::ComplexTPS, t1::TPS)::ComplexTPS - ct = zero(ct1) - mad_ctpsa_subt!(ct1.tpsa, t1.tpsa, ct.tpsa) - return ct +function sub!(t::Union{TPS,ComplexTPS}, a::Number, t1::Union{TPS,ComplexTPS}) + scl!(t1.tpsa, convert(numtype(t), -1.), t.tpsa) + set0!(t.tpsa, convert(numtype(t), 1.), convert(numtype(t), a)) end -function -(t1::TPS, ct1::ComplexTPS)::ComplexTPS - ct = zero(ct1) - mad_ctpsa_tsub!(t1.tpsa, ct1.tpsa, ct.tpsa) - return ct +for t = ((TPS,TPS),(TPS,Real),(Real,TPS),(TPS,Complex),(Complex,TPS),(ComplexTPS,TPS),(TPS,ComplexTPS),(ComplexTPS,ComplexTPS),(ComplexTPS, Number), (Number, ComplexTPS)) +@eval begin +function -(t1::$t[1], t2::$t[2]) + use = $(t[1] == TPS || t[1] == ComplexTPS ? :t1 : :t2) + t = (promote_type(typeof(t1),typeof(t2)))(use=use) + sub!(t, t1, t2) + return t end - -function -(t1::TPS, a::Complex)::ComplexTPS - ct = ComplexTPS(t1) - mad_ctpsa_set0!(ct.tpsa, convert(ComplexF64, 1), convert(ComplexF64, -a)) - return ct end - -function -(a::Complex, t1::TPS)::ComplexTPS - ct = ComplexTPS(t1) - mad_ctpsa_scl!(ct.tpsa, convert(ComplexF64, -1), ct.tpsa) - mad_ctpsa_set0!(ct.tpsa, convert(ComplexF64, 1), convert(ComplexF64,a)) - return ct end - # --- mul --- -# TPS: -function *(t1::TPS, t2::TPS)::TPS - t = zero(t1) - mad_tpsa_mul!(t1.tpsa, t2.tpsa, t.tpsa) +# TPS, TPS: +mul!(a::Ptr{RTPSA}, b::Ptr{RTPSA}, c::Ptr{RTPSA}) = mad_tpsa_mul!(a, b, c) +mul!(a::Ptr{CTPSA}, b::Ptr{CTPSA}, c::Ptr{CTPSA}) = mad_ctpsa_mul!(a, b, c) +mul!(a::Ptr{RTPSA}, b::Ptr{CTPSA}, c::Ptr{CTPSA}) = mad_ctpsa_mult!(b, a, c) +mul!(a::Ptr{CTPSA}, b::Ptr{RTPSA}, c::Ptr{CTPSA}) = mad_ctpsa_mult!(a, b, c) + +mul!(t::Union{TPS,ComplexTPS}, t1::Union{TPS,ComplexTPS}, t2::Union{TPS,ComplexTPS}) = mul!(t1.tpsa, t2.tpsa, t.tpsa) + +# TPS, scalar: +mul!(t::Union{TPS,ComplexTPS}, t1::Union{TPS,ComplexTPS}, a::Number) = scl!(t1.tpsa, convert(numtype(t), a), t.tpsa) +mul!(t::Union{TPS,ComplexTPS}, a::Number, t1::Union{TPS,ComplexTPS}) = mul!(t, t1, a) + +for t = ((TPS,TPS),(TPS,Real),(Real,TPS),(TPS,Complex),(Complex,TPS),(ComplexTPS,TPS),(TPS,ComplexTPS),(ComplexTPS,ComplexTPS),(ComplexTPS, Number), (Number, ComplexTPS)) +@eval begin +function *(t1::$t[1], t2::$t[2]) + use = $(t[1] == TPS || t[1] == ComplexTPS ? :t1 : :t2) + t = (promote_type(typeof(t1),typeof(t2)))(use=use) + mul!(t, t1, t2) return t end - -function *(t1::TPS, a::Real)::TPS - t = zero(t1) - mad_tpsa_scl!(t1.tpsa, convert(Float64, a), t.tpsa) - return t -end - -function *(a::Real, t1::TPS)::TPS - return t1 * a -end - -# ComplexTPS: -function *(ct1::ComplexTPS, ct2::ComplexTPS)::ComplexTPS - ct = zero(ct1) - mad_ctpsa_mul!(ct1.tpsa, ct2.tpsa, ct.tpsa) - return ct -end - -function *(ct1::ComplexTPS, a::Number)::ComplexTPS - ct = zero(ct1) - mad_ctpsa_scl!(ct1.tpsa, convert(ComplexF64,a), ct.tpsa) - return ct -end - -function *(a::Number, ct1::ComplexTPS)::ComplexTPS - return ct1 * a -end - -# TPS to ComplexTPS promotion, w/o creating temp ComplexTPS: -function *(ct1::ComplexTPS, t1::TPS)::ComplexTPS - ct = zero(ct1) - mad_ctpsa_mult!(ct1.tpsa, t1.tpsa, ct.tpsa) - return ct end - -function *(t1::TPS, ct1::ComplexTPS)::ComplexTPS - return ct1 * t1 -end - -function *(t1::TPS, a::Complex)::ComplexTPS - ct = ComplexTPS(t1) - mad_ctpsa_scl!(ct.tpsa, convert(ComplexF64,a), ct.tpsa) - return ct -end - -function *(a::Complex, t1::TPS)::ComplexTPS - return t1 * a end # --- div --- -# TPS: -function /(t1::TPS, t2::TPS)::TPS - t = zero(t1) - mad_tpsa_div!(t1.tpsa, t2.tpsa, t.tpsa) +div!(a::Ptr{RTPSA}, b::Ptr{RTPSA}, c::Ptr{RTPSA}) = mad_tpsa_div!(a, b, c) +div!(a::Ptr{CTPSA}, b::Ptr{CTPSA}, c::Ptr{CTPSA}) = mad_ctpsa_div!(a, b, c) +div!(a::Ptr{CTPSA}, b::Ptr{RTPSA}, c::Ptr{CTPSA}) = mad_ctpsa_divt!(a, b, c) +div!(a::Ptr{RTPSA}, b::Ptr{CTPSA}, c::Ptr{CTPSA}) = mad_ctpsa_tdiv!(a, b, c) + +div!(t::Union{TPS,ComplexTPS}, t1::Union{TPS,ComplexTPS}, t2::Union{TPS,ComplexTPS}) = div!(t1.tpsa, t2.tpsa, t.tpsa) + +# TPS, scalar: +inv!(a::Ptr{RTPSA}, v::Float64, c::Ptr{RTPSA}) = mad_tpsa_inv!(a, v, c) +inv!(a::Ptr{CTPSA}, v::ComplexF64, c::Ptr{CTPSA}) = mad_ctpsa_inv!(a, v, c) +function inv!(a::Ptr{RTPSA}, v::ComplexF64, c::Ptr{CTPSA}) + mad_ctpsa_cplx!(a, Base.unsafe_convert(Ptr{RTPSA}, C_NULL), c) + inv!(c, v, c) +end + +div!(t::Union{TPS,ComplexTPS}, t1::Union{TPS,ComplexTPS}, a::Number) = mul!(t, t1, 1/a) +div!(t::Union{TPS,ComplexTPS}, a::Number, t1::Union{TPS,ComplexTPS}) = inv!(t1.tpsa, convert(numtype(t), a), t.tpsa) + +for t = ((TPS,TPS),(TPS,Real),(Real,TPS),(TPS,Complex),(Complex,TPS),(ComplexTPS,TPS),(TPS,ComplexTPS),(ComplexTPS,ComplexTPS),(ComplexTPS, Number), (Number, ComplexTPS)) +@eval begin +function /(t1::$t[1], t2::$t[2]) + use = $(t[1] == TPS || t[1] == ComplexTPS ? :t1 : :t2) + t = (promote_type(typeof(t1),typeof(t2)))(use=use) + div!(t, t1, t2) return t end - -function /(t1::TPS, a::Real)::TPS - t = zero(t1) - mad_tpsa_scl!(t1.tpsa, convert(Float64, 1/a), t.tpsa) - return t end - -function /(a::Real, t1::TPS)::TPS - t = zero(t1) - mad_tpsa_inv!(t1.tpsa, convert(Float64,a), t.tpsa) - return t end -# ComplexTPS: -function /(ct1::ComplexTPS, ct2::ComplexTPS)::ComplexTPS - ct = zero(ct1) - mad_ctpsa_div!(ct1.tpsa, ct2.tpsa, ct.tpsa) - return ct -end - -function /(ct1::ComplexTPS, a::Number)::ComplexTPS - ct = zero(ct1) - mad_ctpsa_scl!(ct1.tpsa, convert(ComplexF64, 1/a), ct.tpsa) - return ct -end - -function /(a::Number, ct1::ComplexTPS)::ComplexTPS - ct = zero(ct1) - mad_ctpsa_inv!(ct1.tpsa, convert(ComplexF64, a), ct.tpsa) - return ct -end - -# TPS to ComplexTPS promotion, w/o creating temp ComplexTPS: -function /(ct1::ComplexTPS, t1::TPS)::ComplexTPS - ct = zero(ct1) - mad_ctpsa_divt!(ct1.tpsa, t1.tpsa, ct.tpsa) - return ct -end - -function /(t1::TPS, ct1::ComplexTPS)::ComplexTPS - ct = zero(ct1) - mad_ctpsa_tdiv!(t1.tpsa, ct1.tpsa, ct.tpsa) - return ct -end - -function /(t1::TPS, a::Complex)::ComplexTPS - ct = ComplexTPS(t1) - mad_ctpsa_scl!(ct.tpsa, convert(ComplexF64, 1/a), ct.tpsa) - return ct -end - -function /(a::Complex, t1::TPS)::ComplexTPS - ct = ComplexTPS(t1) - mad_ctpsa_inv!(ct.tpsa, convert(ComplexF64, a), ct.tpsa) - return ct -end - - # --- pow --- # TPS: function ^(t1::TPS, t2::TPS)::TPS diff --git a/src/utils.jl b/src/utils.jl index 9f8b76cc..5d556edb 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -133,9 +133,9 @@ Complex{TPS}(t1::TPS, t2::TPS) = error("ComplexTPS can only be defined as an Abs Complex{TPS}(t1::TPS, a::Real) = error("ComplexTPS can only be defined as an AbstractComplex type (to be implemented in Julia PR #35587). If this error was reached without explicitly attempting to create a Complex{TPS}, please submit an issue to GTPSA.jl with an example.") Complex{TPS}(a::Real, t1::TPS) = error("ComplexTPS can only be defined as an AbstractComplex type (to be implemented in Julia PR #35587). If this error was reached without explicitly attempting to create a Complex{TPS}, please submit an issue to GTPSA.jl with an example.") -promote_rule(::Type{TPS}, ::Union{Type{<:AbstractFloat}, Type{<:Integer}, Type{<:Rational}, Type{<:AbstractIrrational}}) = TPS -promote_rule(::Type{ComplexTPS}, ::Union{Type{Complex{<:Real}},Type{<:AbstractFloat}, Type{<:Integer}, Type{<:Rational}, Type{<:AbstractIrrational}}) = ComplexTPS -promote_rule(::Type{TPS}, ::Union{Type{ComplexTPS}, Type{Complex{<:Real}}}) = ComplexTPS +promote_rule(::Type{TPS}, ::Type{T}) where {T<:Real} = TPS #::Union{Type{<:AbstractFloat}, Type{<:Integer}, Type{<:Rational}, Type{<:AbstractIrrational}}) = TPS +promote_rule(::Type{ComplexTPS}, ::Type{T}) where {T<:Number} = ComplexTPS #::Union{Type{Complex{<:Real}},Type{<:AbstractFloat}, Type{<:Integer}, Type{<:Rational}, Type{<:AbstractIrrational}}) = ComplexTPS +promote_rule(::Type{TPS}, ::Type{T}) where {T<:Number}= ComplexTPS # Handle bool which is special for some reason +(t::TPS, z::Complex{Bool}) = t + Complex{Int}(z)