Skip to content

Commit

Permalink
Merge branch 'master' into scalar-periodic
Browse files Browse the repository at this point in the history
  • Loading branch information
simsurace authored Nov 22, 2023
2 parents dc61af1 + 9a2f7bb commit 7a88585
Show file tree
Hide file tree
Showing 3 changed files with 1 addition and 32 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "KernelFunctions"
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
version = "0.10.59"
version = "0.10.60"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
22 changes: 0 additions & 22 deletions src/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,28 +121,6 @@ function ChainRulesCore.rrule(s::Sinus, x::AbstractVector, y::AbstractVector)
return val, evaluate_pullback
end

## Reverse Rules SqMahalanobis

function ChainRulesCore.rrule(
dist::Distances.SqMahalanobis, a::AbstractVector, b::AbstractVector
)
d = dist(a, b)
function SqMahalanobis_pullback::Real)
a_b = a - b
∂qmat = InplaceableThunk(
-> mul!(X̄, a_b, a_b', true, Δ), @thunk((a_b * a_b') * Δ)
)
∂a = InplaceableThunk(
-> mul!(X̄, dist.qmat, a_b, true, 2 * Δ), @thunk((2 * Δ) * dist.qmat * a_b)
)
∂b = InplaceableThunk(
-> mul!(X̄, dist.qmat, a_b, true, -2 * Δ), @thunk((-2 * Δ) * dist.qmat * a_b)
)
return Tangent{typeof(dist)}(; qmat=∂qmat), ∂a, ∂b
end
return d, SqMahalanobis_pullback
end

## Reverse Rules for matrix wrappers

function ChainRulesCore.rrule(::Type{<:ColVecs}, X::AbstractMatrix)
Expand Down
9 changes: 0 additions & 9 deletions test/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
x = rand(rng, 5)
y = rand(rng, 5)
r = rand(rng, 5)
Q = Matrix(Cholesky(rand(rng, 5, 5), 'U', 0))
@assert isposdef(Q)

compare_gradient(:Zygote, [x, y]) do xy
Euclidean()(xy[1], xy[2])
Expand All @@ -21,11 +19,4 @@
compare_gradient(:Zygote, [x, y]) do xy
KernelFunctions.Sinus(r)(xy[1], xy[2])
end
if VERSION < v"1.6"
@test_broken "Chain rule of SqMahalanobis is broken in Julia pre-1.6"
else
compare_gradient(:Zygote, [Q, x, y]) do Qxy
SqMahalanobis(Qxy[1])(Qxy[2], Qxy[3])
end
end
end

0 comments on commit 7a88585

Please sign in to comment.