Introduce DiffPt for the covariance function of derivatives #508

Draft
wants to merge 17 commits into master
2 changes: 2 additions & 0 deletions Project.toml
@@ -8,10 +8,12 @@ Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
CompositionsBase = "a33af91c-f02d-484b-be07-31d278c5ca2b"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Member

I think we don't want a dependency on ForwardDiff. We have tried hard to avoid it so far.

Author (@FelixBenning, May 17, 2023)

Yeah, maybe this should be a plugin somehow, but writing it as a plugin would have caused a bunch of boilerplate to review, and I wanted to make the core idea easier to grasp.

Author

Maybe AbstractDifferentiation.jl would be the right abstraction?

Member

IMO it's not ready for proper use. But hopefully it will be at some point.

Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
FelixBenning marked this conversation as resolved.
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
3 changes: 3 additions & 0 deletions src/KernelFunctions.jl
@@ -43,6 +43,8 @@ export MOInput, prepare_isotopic_multi_output_data, prepare_heterotopic_multi_ou
export IndependentMOKernel,
LatentFactorMOKernel, IntrinsicCoregionMOKernel, LinearMixingModelKernel

export DiffPt

# Reexports
export tensor, ⊗, compose

@@ -125,6 +127,7 @@ include("chainrules.jl")
include("zygoterules.jl")

include("TestUtils.jl")
include("diffKernel.jl")
Member

Kernels are contained in the kernel subfolder.

Author (@FelixBenning, May 17, 2023)

Well, it is not a new kernel - it extends the functionality of all kernels, which is why I did not put it there. But maybe it makes sense to put it there anyway.

Member

I think it should be a new kernel. That's how you can nicely fit it into the KernelFunctions ecosystem.


function __init__()
@require Kronecker = "2c470bb0-bcc8-11e8-3dad-c9649493f05e" begin
108 changes: 108 additions & 0 deletions src/diffKernel.jl
@@ -0,0 +1,108 @@
using OneHotArrays: OneHotVector
import ForwardDiff as FD
import LinearAlgebra as LA

"""
DiffPt(x; partial=())
Member

Why is a separate type needed? Wouldn't it be better to use the existing input formats for multi-output kernels?

Author

Where can I find the existing input formats? https://juliagaussianprocesses.github.io/KernelFunctions.jl/stable/design/#inputs_for_multiple_outputs says you would use Tuple{T, Int}, and while you could in principle convert a partial tuple to an Int, you really don't want to.

I mean, even if we are generous and start counting with zero to make this less of a mess:

()     -> 0          # no derivative
(1,)   -> 1          # partial derivative in direction 1
(2,)   -> 2
...
(dim,) -> dim        # partial derivative in direction dim
(1,1)  -> dim + 1    # twice partial derivative in direction 1
(1,2)  -> dim + 2
...
(k,j)  -> k*dim + j

So the user would have to do this conversion from Cartesian coordinates to linear coordinates, and the implementation would then have to revert this transformation back from linear coordinates to Cartesian coordinates (Cartesian coordinates with an ex ante unknown tuple length, that is).

That all seems annoying without being really necessary. But sure, you could do it.
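To make the round trip concrete, here is an editorial sketch (not part of the PR) of what such an encoding and its inverse could look like for partial tuples of length at most two; `encode_partial` and `decode_partial` are hypothetical names:

```julia
# Hypothetical linearisation of partial-derivative tuples, following the scheme
# above: () -> 0, (i,) -> i, (i, j) -> i*dim + j.
function encode_partial(partials::Tuple, dim::Int)
    isempty(partials) && return 0                 # no derivative
    length(partials) == 1 && return partials[1]   # ∂ᵢ ↦ i
    i, j = partials
    return i * dim + j                            # ∂ᵢ∂ⱼ ↦ i*dim + j
end

# ... and the inverse transformation the implementation would then have to undo again.
function decode_partial(idx::Int, dim::Int)
    idx == 0 && return ()
    idx <= dim && return (idx,)
    j = mod1(idx, dim)
    return ((idx - j) ÷ dim, j)
end
```

Here `decode_partial(encode_partial((1, 2), 3), 3) == (1, 2)`, but the scheme only covers tuples up to a fixed length, which is exactly the objection above.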

Member

Don't you just need $X \times \{0, \ldots, \mathrm{dim}\}$ to represent all desired partials?

Author (@FelixBenning, May 17, 2023)

What is X? (BTW, MathJax is enabled: $ works for inline and ```math works for multiline maths.)

Author (@FelixBenning, May 17, 2023)

And this is a bit tricky:

$$\begin{aligned} f&: \mathbb{R}^d \to \mathbb{R} \\ \nabla f(x) &\in \mathbb{R}^d \\ \nabla^2 f(x) &\in \mathbb{R}^{d^2} \end{aligned}$$

So $(f(x), \nabla f(x), \nabla^2 f(x)) \in \mathbb{R} \times \mathbb{R}^d \times \mathbb{R}^{d\times d} = \mathbb{R}^{1+d + d^2}$

So more generally $(f(x), \nabla f(x), ..., f^{(n)}(x)) \in \mathbb{R}^m$ with

$$m = \sum_{k=0}^{n} d^k = \frac{d^{n+1} - 1}{d - 1} \quad (d > 1)$$

In essence, a variable-length tuple carries more information than just the longest tuple. Plus, there is no longest tuple.
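As a quick numerical check of that count (an editorial example, not from the thread): for $d = 2$ and derivatives up to order $n = 2$,

```math
m = \sum_{k=0}^{2} 2^k = 1 + 2 + 4 = 7 = \frac{2^{3} - 1}{2 - 1}.
```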


For a covariance kernel k of GP Z, i.e.
```julia
k(x,y) # = Cov(Z(x), Z(y)),
```
a DiffPt allows the differentiation of Z, i.e.
```julia
k(DiffPt(x, partial=1), y) # = Cov(∂₁Z(x), Z(y))
```
For higher-order derivatives, `partial` can be any iterable, e.g.
```julia
k(DiffPt(x, partial=(1,2)), y) # = Cov(∂₁∂₂Z(x), Z(y))
```
"""
struct DiffPt{Dim}
FelixBenning marked this conversation as resolved.
pos # the actual position
partial # the partial derivative directions, e.g. (1, 2) for ∂₁∂₂
end

DiffPt(x; partial=()) = DiffPt{length(x)}(x, partial) # convenience constructor

"""
Take the partial derivative of a function `fun` with input dimension `dim`.
If `partials=(i, j)`, then (∂ᵢ∂ⱼ fun) is returned.
"""
function partial(fun, dim, partials=())
FelixBenning marked this conversation as resolved.
if !isnothing(local next = iterate(partials))
idx, state = next
return partial(
x -> FD.derivative(0) do dx
fun(x .+ dx * OneHotVector(idx, dim))
end,
dim,
Base.rest(partials, state),
)
end
return fun
end

"""
Take the partial derivative of a function with two `dim`-dimensional inputs,
i.e. a 2*dim-dimensional input overall.
"""
function partial(k, dim; partials_x=(), partials_y=())
local f(x,y) = partial(t -> k(t,y), dim, partials_x)(x)
return (x,y) -> partial(t -> f(x,t), dim, partials_y)(y)
end
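To make the intended behaviour of these helpers concrete, here is a short editorial usage sketch (not part of the diff); it calls the `partial` helper above directly and only assumes the imports at the top of this file:

```julia
g(x) = x[1]^2 * x[2]            # toy function on ℝ²

∂₁g = partial(g, 2, (1,))       # x -> 2*x[1]*x[2]
∂₁g([3.0, 5.0])                 # ≈ 30.0

∂₁∂₂g = partial(g, 2, (1, 2))   # x -> 2*x[1]
∂₁∂₂g([3.0, 5.0])               # ≈ 6.0
```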




"""
_evaluate(k::T, x::DiffPt{Dim}, y::DiffPt{Dim}) where {Dim, T<:Kernel}

implements `(k::T)(x::DiffPt{Dim}, y::DiffPt{Dim})` for all kernel types. Since the
dispatch system does not allow generics in the syntax above, this redirection over
`_evaluate` is necessary.
Member

See my other comment, I think this should not be done and a simple wrapper would be sufficient.


It unboxes the partial instructions from the `DiffPt`s, applies them to `k`, and
evaluates the result at the positions stored in the `DiffPt`s.
"""
function _evaluate(k::T, x::DiffPt{Dim}, y::DiffPt{Dim}) where {Dim, T<:Kernel}
return partial(
k, Dim,
partials_x=x.partial, partials_y=y.partial
)(x.pos, y.pos)
end



#=
This is a hack to work around the fact that the `where {T<:Kernel}` clause is
not allowed for the `(::T)(x,y)` syntax. If we were to only implement
```julia
(::Kernel)(::DiffPt,::DiffPt)
```
then Julia would not know whether to use
`(::SpecialKernel)(x, y)` or `(::Kernel)(x::DiffPt, y::DiffPt)`.
Comment on lines +88 to +95
Author

To avoid this hack, no kernel type T should implement

	(::T)(x,y)

and instead implement

	_evaluate(k::T, x, y)

Then there should be only a single

	(k::Kernel)(x,y) = _evaluate(k, x, y)

which all the kernels would fall back to.

This ensures that `_evaluate(k::T, x::DiffPt{Dim}, y::DiffPt{Dim}) where T<:Kernel` is always
the more specific method and therefore gets called.

Author

But this is a much more intrusive change, so to avoid blowing up the number of changed lines, I did not do it yet.

Author

This is also why "detect ambiguities" fails right now.

Member

I don't think we should switch to _evaluate(k, x, y). Instead of changing all other kernels, I think you should just create a single wrapper of existing kernels.

Author (@FelixBenning, May 17, 2023)

The ugly thing about wrapping existing kernels is that you essentially say: this is a different Gaussian process model.

I.e.

f = GP(MaternKernel())

is a different (non-differentiable) model from

g = GP(DiffWrapper(MaternKernel())),

which is differentiable. But fundamentally the Matérn kernel implies that the Gaussian process should always be $\lfloor \nu \rfloor$-times differentiable. So f ought to be differentiable.

And with this abstraction it is. I.e. you use

x = 1:10
fx = f(x)
y = rand(fx)

to simulate f at points 1:10. If you now wanted to simulate its gradient at point 0 too, you would just have to modify its input

x = [DiffPt(0, partial=1), 1:10... ]
fx = f(x)
y0_grad, y... =  rand(fx)

and y0_grad would be the gradient at 0, as expected.

Author

Also note how this abstraction lets you mix and match normal points and DiffPts.

Member

> Also note how this abstraction lets you mix and match normal points and DiffPts.

You can perform exactly the same computations with a wrapper type.

> The ugly thing about wrapping existing kernels is that you essentially say: this is a different Gaussian process model.

I don't view it this way. The wrapper does not change the mathematical model, it just allows you to query derivatives as well.

Author (@FelixBenning, May 17, 2023)

> You can perform exactly the same computations with a wrapper type.

Well, if the wrapped kernel has a superset of the capabilities of the original kernel (without performance cost), why would you ever use the unwrapped kernel? And if you only ever use the wrapped kernel, then
GP(DiffWrapper(MaternKernel()))
is kind of superfluous. So you would probably start to write convenience functions to get the wrapped kernel immediately, but for the wrapped kernel the compositions like +, ... are not implemented, so you would have to pass all of those through. It seems like a pointless effort to get a capability which the original kernel should already have. I mean, the capability does not collide with anything.

Member

No, it's not superfluous as the wrapped kernel has a different API (i.e., you have to provide different types/structure of inputs) that is more inconvenient to work with if you're not interested in evaluating those. There's a difference - but IMO it's not a mathematical one but rather regarding user experience.

As you've already noticed I think it's also just not feasible to extend every implementation of differentiable kernels out there without making implementations of kernels more inconvenient for people that do not want to work with derivatives. So clearly separating these use cases seems simpler to me from a design perspective.

The wrapper would be a Kernel as well, of course, so if the existing implementations in KernelFunctions such as sums etc. are written as generally as they were intended no new definitions for compositions are needed. You would add them only if there is a clear performance gain that outweighs code complexity.
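For readers skimming the thread, a minimal sketch of the wrapper design discussed here could look as follows; `DiffWrapper` is just the hypothetical name used in this conversation, and the sketch reuses the `DiffPt`/`partial` machinery from this PR instead of the `for T in [SimpleKernel, Kernel]` loop at the bottom of diffKernel.jl:

```julia
# Sketch: a dedicated wrapper kernel instead of extending every Kernel.
struct DiffWrapper{K<:Kernel} <: Kernel
    kernel::K
end

# Plain inputs fall straight through to the wrapped kernel ...
(dw::DiffWrapper)(x, y) = dw.kernel(x, y)

# ... and DiffPt inputs trigger the derivative machinery from diffKernel.jl.
function (dw::DiffWrapper)(x::DiffPt{Dim}, y::DiffPt{Dim}) where {Dim}
    return partial(dw.kernel, Dim; partials_x=x.partial, partials_y=y.partial)(x.pos, y.pos)
end
(dw::DiffWrapper)(x::DiffPt{Dim}, y) where {Dim} = dw(x, DiffPt(y))
(dw::DiffWrapper)(x, y::DiffPt{Dim}) where {Dim} = dw(DiffPt(x), y)
```

Whether specialized compositions are worth adding would then indeed be a pure performance question, as noted above.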

Author

No - you do not have to provide different types/structures of inputs, so it does not make things more inconvenient for people who are not interested in gradients. That is the entire point of specializing on DiffPt: the kernel function still works on everything that is not a DiffPt.
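For illustration, the mix-and-match behaviour claimed here is the same thing the smoke test below exercises (editorial sketch using this PR's code):

```julia
k = SEKernel()
k([1.0], [2.0])                      # plain inputs: behaves exactly as before
k(DiffPt([1.0]; partial=1), [2.0])   # mixed inputs: Cov(∂₁Z(x), Z(y))
```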

To avoid this hack, no kernel type T should implement
```julia
(::T)(x,y)
```
and instead implement
```julia
_evaluate(k::T, x, y)
```
Then there should be only a single
```julia
(k::Kernel)(x,y) = _evaluate(k, x, y)
```
which all the kernels would fall back to.

This ensures that _evaluate(k::T, x::DiffPt{Dim}, y::DiffPt{Dim}) is always
the more specific method and therefore gets called.
=#
for T in [SimpleKernel, Kernel] #subtypes(Kernel)
(k::T)(x::DiffPt{Dim}, y::DiffPt{Dim}) where {Dim} = _evaluate(k, x, y)
(k::T)(x::DiffPt{Dim}, y) where {Dim} = _evaluate(k, x, DiffPt(y))
(k::T)(x, y::DiffPt{Dim}) where {Dim} = _evaluate(k, DiffPt(x), y)
end

25 changes: 25 additions & 0 deletions test/diffKernel.jl
@@ -0,0 +1,25 @@
@testset "diffKernel" begin
@testset "smoke test" begin
k = MaternKernel()
k(1,1)
k(1, DiffPt(1, partial=(1,1))) # Cov(Z(x), ∂₁∂₁Z(y)) where x=1, y=1
k(DiffPt([1], partial=1), [2]) # Cov(∂₁Z(x), Z(y)) where x=[1], y=[2]
k(DiffPt([1, 2], partial=(1,)), DiffPt([1, 2], partial=2)) # Cov(∂₁Z(x), ∂₂Z(y)) where x=[1,2], y=[1,2]
end

@testset "Sanity Checks with $k" for k in [SEKernel()]
for x in [0, 1, -1, 42]
# for stationary kernels Cov(∂Z(x) , Z(x)) = 0
@test k(DiffPt(x, partial=1), x) ≈ 0

# the slope should be positively correlated with a point further down
@test k(
DiffPt(x, partial=1), # slope
x + 1e-1 # point further down
) > 0

# correlation with self should be positive
@test k(DiffPt(x, partial=1), DiffPt(x, partial=1)) > 0
end
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
@@ -176,6 +176,7 @@ include("test_utils.jl")
include("generic.jl")
include("chainrules.jl")
include("zygoterules.jl")
include("diffKernel.jl")

@testset "doctests" begin
DocMeta.setdocmeta!(