Skip to content

Commit

Permalink
attempt a fix
Browse files Browse the repository at this point in the history
  • Loading branch information
FelixBenning committed May 17, 2023
1 parent e787af4 commit 0f93c1c
Showing 1 changed file with 14 additions and 16 deletions.
30 changes: 14 additions & 16 deletions src/diffKernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@ for higher order derivatives partial can be any iterable, i.e.
k(DiffPt(x, partial=(1,2)), y) # = Cov(∂₁∂₂Z(x), Z(y))
```
"""
struct DiffPt{Dim}
struct DiffPt
pos # the actual position
partial
end

DiffPt(x; partial=()) = DiffPt{length(x)}(x, partial) # convenience constructor
DiffPt(x; partial=()) = DiffPt(x, partial) # convenience constructor

"""
partial(fun, idx)
Expand All @@ -34,10 +34,8 @@ Return ∂ᵢf where
"""
function partial(fun, idx)
return x -> FD.derivative(0) do dx
y = similar(x)
y = copyto!(y, x)
y[idx] += dx
fun(y)
dim = length(x)
fun(x .+ dx * OneHotVector(idx, dim))
end
end

Expand All @@ -58,23 +56,23 @@ end
Take the partial derivative of a function with two dim-dimensional inputs,
i.e. 2*dim dimensional input
"""
function partial(k, dim; partials_x=(), partials_y=())
local f(x, y) = partial(t -> k(t, y), dim, partials_x)(x)
return (x, y) -> partial(t -> f(x, t), dim, partials_y)(y)
function partial(k; partials_x=(), partials_y=())
local f(x, y) = partial(t -> k(t, y), partials_x...)(x)
return (x, y) -> partial(t -> f(x, t), partials_y...)(y)
end

"""
_evaluate(k::T, x::DiffPt{Dim}, y::DiffPt{Dim}) where {Dim, T<:Kernel}
_evaluate(k::T, x::DiffPt, y::DiffPt) where {T<:Kernel}
implements `(k::T)(x::DiffPt{Dim}, y::DiffPt{Dim})` for all kernel types. But since
implements `(k::T)(x::DiffPt, y::DiffPt)` for all kernel types. But since
generics are not allowed in the syntax above by the dispatch system, this
redirection over `_evaluate` is necessary
unboxes the partial instructions from DiffPt and applies them to k,
evaluates them at the positions of DiffPt
"""
function _evaluate(k::T, x::DiffPt{Dim}, y::DiffPt{Dim}) where {Dim,T<:Kernel}
return partial(k, Dim; partials_x=x.partial, partials_y=y.partial)(x.pos, y.pos)
function _evaluate(k::T, x::DiffPt, y::DiffPt) where {T<:Kernel}
return partial(k, partials_x=x.partial, partials_y=y.partial)(x.pos, y.pos)
end

#=
Expand All @@ -101,7 +99,7 @@ for T in [
NormalizedKernel,
KernelTensorProduct
] #subtypes(Kernel)
(k::T)(x::DiffPt{Dim}, y::DiffPt{Dim}) where {Dim} = _evaluate(k, x, y)
(k::T)(x::DiffPt{Dim}, y) where {Dim} = _evaluate(k, x, DiffPt(y))
(k::T)(x, y::DiffPt{Dim}) where {Dim} = _evaluate(k, DiffPt(x), y)
(k::T)(x::DiffPt, y::DiffPt)= _evaluate(k, x, y)
(k::T)(x::DiffPt, y) = _evaluate(k, x, DiffPt(y))
(k::T)(x, y::DiffPt) = _evaluate(k, DiffPt(x), y)
end

0 comments on commit 0f93c1c

Please sign in to comment.