From 0f93c1c69ddfb916b58f819db3493ff041b3e7b0 Mon Sep 17 00:00:00 2001 From: Felix Benning Date: Wed, 17 May 2023 16:09:22 +0200 Subject: [PATCH] attempt a fix --- src/diffKernel.jl | 30 ++++++++++++++---------------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/src/diffKernel.jl b/src/diffKernel.jl index fc8ff197c..425d33e14 100644 --- a/src/diffKernel.jl +++ b/src/diffKernel.jl @@ -18,12 +18,12 @@ for higher order derivatives partial can be any iterable, i.e. k(DiffPt(x, partial=(1,2)), y) # = Cov(∂₁∂₂Z(x), Z(y)) ``` """ -struct DiffPt{Dim} +struct DiffPt pos # the actual position partial end -DiffPt(x; partial=()) = DiffPt{length(x)}(x, partial) # convenience constructor +DiffPt(x; partial=()) = DiffPt(x, partial) # convenience constructor """ partial(fun, idx) @@ -34,10 +34,8 @@ Return ∂ᵢf where """ function partial(fun, idx) return x -> FD.derivative(0) do dx - y = similar(x) - y = copyto!(y, x) - y[idx] += dx - fun(y) + dim = length(x) + fun(x .+ dx * OneHotVector(idx, dim)) end end @@ -58,23 +56,23 @@ end Take the partial derivative of a function with two dim-dimensional inputs, i.e. 2*dim dimensional input """ -function partial(k, dim; partials_x=(), partials_y=()) - local f(x, y) = partial(t -> k(t, y), dim, partials_x)(x) - return (x, y) -> partial(t -> f(x, t), dim, partials_y)(y) +function partial(k; partials_x=(), partials_y=()) + local f(x, y) = partial(t -> k(t, y), partials_x...)(x) + return (x, y) -> partial(t -> f(x, t), partials_y...)(y) end """ - _evaluate(k::T, x::DiffPt{Dim}, y::DiffPt{Dim}) where {Dim, T<:Kernel} + _evaluate(k::T, x::DiffPt, y::DiffPt) where {T<:Kernel} -implements `(k::T)(x::DiffPt{Dim}, y::DiffPt{Dim})` for all kernel types. But since +implements `(k::T)(x::DiffPt, y::DiffPt)` for all kernel types. But since generics are not allowed in the syntax above by the dispatch system, this redirection over `_evaluate` is necessary unboxes the partial instructions from DiffPt and applies them to k, evaluates them at the positions of DiffPt """ -function _evaluate(k::T, x::DiffPt{Dim}, y::DiffPt{Dim}) where {Dim,T<:Kernel} - return partial(k, Dim; partials_x=x.partial, partials_y=y.partial)(x.pos, y.pos) +function _evaluate(k::T, x::DiffPt, y::DiffPt) where {T<:Kernel} + return partial(k, partials_x=x.partial, partials_y=y.partial)(x.pos, y.pos) end #= @@ -101,7 +99,7 @@ for T in [ NormalizedKernel, KernelTensorProduct ] #subtypes(Kernel) - (k::T)(x::DiffPt{Dim}, y::DiffPt{Dim}) where {Dim} = _evaluate(k, x, y) - (k::T)(x::DiffPt{Dim}, y) where {Dim} = _evaluate(k, x, DiffPt(y)) - (k::T)(x, y::DiffPt{Dim}) where {Dim} = _evaluate(k, DiffPt(x), y) + (k::T)(x::DiffPt, y::DiffPt)= _evaluate(k, x, y) + (k::T)(x::DiffPt, y) = _evaluate(k, x, DiffPt(y)) + (k::T)(x, y::DiffPt) = _evaluate(k, DiffPt(x), y) end