From 9a2f7bbec515c55e5594feef7928670cb811169c Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace <51025924+simsurace@users.noreply.github.com> Date: Mon, 20 Nov 2023 13:53:06 +0100 Subject: [PATCH] Remove chainrule and test for `SqMahalanobis` (#539) --- Project.toml | 2 +- src/chainrules.jl | 22 ---------------------- test/chainrules.jl | 9 --------- 3 files changed, 1 insertion(+), 32 deletions(-) diff --git a/Project.toml b/Project.toml index 19247efdf..c46c5c334 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "KernelFunctions" uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392" -version = "0.10.59" +version = "0.10.60" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/chainrules.jl b/src/chainrules.jl index eebdf95b5..d31ec97d1 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -121,28 +121,6 @@ function ChainRulesCore.rrule(s::Sinus, x::AbstractVector, y::AbstractVector) return val, evaluate_pullback end -## Reverse Rules SqMahalanobis - -function ChainRulesCore.rrule( - dist::Distances.SqMahalanobis, a::AbstractVector, b::AbstractVector -) - d = dist(a, b) - function SqMahalanobis_pullback(Δ::Real) - a_b = a - b - ∂qmat = InplaceableThunk( - X̄ -> mul!(X̄, a_b, a_b', true, Δ), @thunk((a_b * a_b') * Δ) - ) - ∂a = InplaceableThunk( - X̄ -> mul!(X̄, dist.qmat, a_b, true, 2 * Δ), @thunk((2 * Δ) * dist.qmat * a_b) - ) - ∂b = InplaceableThunk( - X̄ -> mul!(X̄, dist.qmat, a_b, true, -2 * Δ), @thunk((-2 * Δ) * dist.qmat * a_b) - ) - return Tangent{typeof(dist)}(; qmat=∂qmat), ∂a, ∂b - end - return d, SqMahalanobis_pullback -end - ## Reverse Rules for matrix wrappers function ChainRulesCore.rrule(::Type{<:ColVecs}, X::AbstractMatrix) diff --git a/test/chainrules.jl b/test/chainrules.jl index 03c2c3b1f..5c3c5766b 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -3,8 +3,6 @@ x = rand(rng, 5) y = rand(rng, 5) r = rand(rng, 5) - Q = Matrix(Cholesky(rand(rng, 5, 5), 'U', 0)) - @assert isposdef(Q) compare_gradient(:Zygote, [x, y]) do xy Euclidean()(xy[1], xy[2]) @@ -21,11 +19,4 @@ compare_gradient(:Zygote, [x, y]) do xy KernelFunctions.Sinus(r)(xy[1], xy[2]) end - if VERSION < v"1.6" - @test_broken "Chain rule of SqMahalanobis is broken in Julia pre-1.6" - else - compare_gradient(:Zygote, [Q, x, y]) do Qxy - SqMahalanobis(Qxy[1])(Qxy[2], Qxy[3]) - end - end end