From f31f1606bc6b375974cd0242d7cbb5ebe7558cd9 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Tue, 23 Jan 2024 21:21:06 +0100 Subject: [PATCH 1/3] Optimize `pairwise` for `Sinus` and scalar inputs --- src/distances/sinus.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/distances/sinus.jl b/src/distances/sinus.jl index 2fdadfcf9..1d2834c3c 100644 --- a/src/distances/sinus.jl +++ b/src/distances/sinus.jl @@ -21,3 +21,7 @@ Distances.result_type(::Sinus{T}, Ta::Type, Tb::Type) where {T} = promote_type(T end return sum(abs2, sinpi.(a - b) ./ d.r) end + +# Optimizations for scalar inputs (avoiding allocations) +pairwise(d::Sinus, x::AbstractVector{<:Real}) = pairwise(d, x, x) +pairwise(d::Sinus, x::AbstractVector{<:Real}, y::AbstractVector{<:Real}) = abs2.(sinpi.(x .- y') ./ only(d.r)) From 912cc49dd87838e08606604ca034d06c40625d70 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace <51025924+simsurace@users.noreply.github.com> Date: Tue, 23 Jan 2024 21:21:06 +0100 Subject: [PATCH 2/3] Improve formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/distances/sinus.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/distances/sinus.jl b/src/distances/sinus.jl index 1d2834c3c..706c1f5f8 100644 --- a/src/distances/sinus.jl +++ b/src/distances/sinus.jl @@ -24,4 +24,6 @@ end # Optimizations for scalar inputs (avoiding allocations) pairwise(d::Sinus, x::AbstractVector{<:Real}) = pairwise(d, x, x) -pairwise(d::Sinus, x::AbstractVector{<:Real}, y::AbstractVector{<:Real}) = abs2.(sinpi.(x .- y') ./ only(d.r)) +function pairwise(d::Sinus, x::AbstractVector{<:Real}, y::AbstractVector{<:Real}) + return abs2.(sinpi.(x .- y') ./ only(d.r)) +end From f990586d3fdac6196de844e0f8793d6d7048a982 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Tue, 23 Jan 2024 21:21:06 +0100 Subject: [PATCH 3/3] Avoid method ambiguities --- src/distances/sinus.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/distances/sinus.jl b/src/distances/sinus.jl index 706c1f5f8..d6ba786a4 100644 --- a/src/distances/sinus.jl +++ b/src/distances/sinus.jl @@ -27,3 +27,9 @@ pairwise(d::Sinus, x::AbstractVector{<:Real}) = pairwise(d, x, x) function pairwise(d::Sinus, x::AbstractVector{<:Real}, y::AbstractVector{<:Real}) return abs2.(sinpi.(x .- y') ./ only(d.r)) end +pairwise(d::Sinus, x::RowVecs) = Distances_pairwise(d, x.X; dims=1) +pairwise(d::Sinus, x::ColVecs) = Distances_pairwise(d, x.X; dims=2) +pairwise(d::Sinus, x::RowVecs, y::RowVecs) = Distances_pairwise(d, x.X, y.X; dims=1) +pairwise(d::Sinus, x::ColVecs, y::RowVecs) = pairwise(d, x, ColVecs(permutedims(y.X))) +pairwise(d::Sinus, x::RowVecs, y::ColVecs) = pairwise(d, ColVecs(permutedims(x.X)), y) +pairwise(d::Sinus, x::ColVecs, y::ColVecs) = Distances_pairwise(d, x.X, y.X; dims=2)