Introduce DiffPt for the covariance function of derivatives #508
Changes to the package's main module (exports and includes):

```diff
@@ -43,6 +43,8 @@ export MOInput, prepare_isotopic_multi_output_data, prepare_heterotopic_multi_ou
 export IndependentMOKernel,
     LatentFactorMOKernel, IntrinsicCoregionMOKernel, LinearMixingModelKernel
 
+export DiffPt
+
 # Reexports
 export tensor, ⊗, compose
 
@@ -125,6 +127,7 @@ include("chainrules.jl")
 include("zygoterules.jl")
 
 include("TestUtils.jl")
+include("diffKernel.jl")
 
 function __init__()
     @require Kronecker = "2c470bb0-bcc8-11e8-3dad-c9649493f05e" begin
```

**Review thread** (on `include("diffKernel.jl")`):

> Kernels are contained in the kernel subfolder.

> Well, it is not a new kernel - it extends the functionality of all kernels, which is why I did not put it there. But maybe it makes sense to put it there anyway.

> I think it should be a new kernel. That's how you can nicely fit it into the KernelFunctions ecosystem.
New file `diffKernel.jl` (+108 lines):

```julia
using OneHotArrays: OneHotVector
import ForwardDiff as FD
import LinearAlgebra as LA
```

**Review thread** (on the `DiffPt` docstring below):

> Why is a separate type needed? Wouldn't it be better to use the existing input formats for multi-output kernels?

> Where can I find the existing input formats? According to https://juliagaussianprocesses.github.io/KernelFunctions.jl/stable/design/#inputs_for_multiple_outputs you would use tuples of an input point and an integer output index. Even if we are gracious and start counting with zero to make this less of a mess, the user would have to do the conversion from cartesian coordinates to linear coordinates, and then the implementation would revert this transformation back from linear coordinates to cartesian coordinates - cartesian coordinates with ex ante unknown tuple length, that is. That all seems annoying without being really necessary. But sure - you could do it.

> Don't you just need …?

> What is X? (btw, MathJax is enabled)

> And this is a bit tricky. In essence, a variable-length tuple carries more information than just the longest tuple. Plus, there is no longest tuple.
````julia
"""
    DiffPt(x; partial=())

For a covariance kernel k of GP Z, i.e.
```julia
k(x, y) # = Cov(Z(x), Z(y)),
```
a DiffPt allows the differentiation of Z, i.e.
```julia
k(DiffPt(x, partial=1), y) # = Cov(∂₁Z(x), Z(y))
```
For higher-order derivatives, partial can be any iterable, e.g.
```julia
k(DiffPt(x, partial=(1, 2)), y) # = Cov(∂₁∂₂Z(x), Z(y))
```
"""
struct DiffPt{Dim}
    pos # the actual position
    partial
end

DiffPt(x; partial=()) = DiffPt{length(x)}(x, partial) # convenience constructor
````
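For concreteness, a few ways the convenience constructor above can be called (illustrative only, consistent with the docstring; not additional code from the PR):

```julia
DiffPt([1.0, 2.0])                 # plain evaluation point, partial=(), Dim == 2
DiffPt([1.0, 2.0]; partial=1)      # ∂₁ at [1.0, 2.0]
DiffPt([1.0, 2.0]; partial=(1, 2)) # ∂₁∂₂ at [1.0, 2.0]
DiffPt(1.0; partial=(1, 1))        # scalars work too: length(1.0) == 1, so Dim == 1
```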
""" | ||
Take the partial derivative of a function `fun` with input dimesion `dim`. | ||
If partials=(i,j), then (∂ᵢ∂ⱼ fun) is returned. | ||
""" | ||
function partial(fun, dim, partials=()) | ||
FelixBenning marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if !isnothing(local next = iterate(partials)) | ||
idx, state = next | ||
return partial( | ||
x -> FD.derivative(0) do dx | ||
fun(x .+ dx * OneHotVector(idx, dim)) | ||
end, | ||
dim, | ||
Base.rest(partials, state), | ||
) | ||
end | ||
return fun | ||
end | ||
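To make the recursion concrete, here is a small check one could run against the `partial` defined above (the function `f` and the expected values are illustrative, not part of the PR):

```julia
# f(x) = x₁² * x₂ on a 2-dimensional input
f(x) = x[1]^2 * x[2]

∂₁f   = partial(f, 2, (1,))    # x -> 2 * x[1] * x[2]
∂₁∂₂f = partial(f, 2, (1, 2))  # x -> 2 * x[1]

∂₁f([3.0, 5.0])    # ≈ 30.0
∂₁∂₂f([3.0, 5.0])  # ≈ 6.0
```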
""" | ||
Take the partial derivative of a function with two dim-dimensional inputs, | ||
i.e. 2*dim dimensional input | ||
""" | ||
function partial(k, dim; partials_x=(), partials_y=()) | ||
local f(x,y) = partial(t -> k(t,y), dim, partials_x)(x) | ||
return (x,y) -> partial(t -> f(x,t), dim, partials_y)(y) | ||
end | ||
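Similarly, a hypothetical check of the keyword method, which takes one multi-index per argument (the function `g` and the name `∂x∂y_g` are illustrative, not from the PR):

```julia
# g(x, y) = (x[1] * y[1])^2 with dim = 1
g(x, y) = (x[1] * y[1])^2

∂x∂y_g = partial(g, 1; partials_x=(1,), partials_y=(1,))
∂x∂y_g([2.0], [3.0])  # ∂ₓ∂ᵧ (x²y²) = 4xy, so ≈ 24.0
```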
""" | ||
_evaluate(k::T, x::DiffPt{Dim}, y::DiffPt{Dim}) where {Dim, T<:Kernel} | ||
|
||
implements `(k::T)(x::DiffPt{Dim}, y::DiffPt{Dim})` for all kernel types. But since | ||
generics are not allowed in the syntax above by the dispatch system, this | ||
redirection over `_evaluate` is necessary | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. See my other comment, I think this should not be done and a simple wrapper would be sufficient. |
||
|
||
unboxes the partial instructions from DiffPt and applies them to k, | ||
evaluates them at the positions of DiffPt | ||
""" | ||
function _evaluate(k::T, x::DiffPt{Dim}, y::DiffPt{Dim}) where {Dim, T<:Kernel} | ||
return partial( | ||
k, Dim, | ||
partials_x=x.partial, partials_y=y.partial | ||
)(x.pos, y.pos) | ||
end | ||
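To connect `_evaluate` with the sanity checks in the test file further down: for the squared-exponential kernel k(x, y) = exp(-(x - y)²/2) in one dimension, ∂ₓ k(x, y) = -(x - y)·exp(-(x - y)²/2), which vanishes at x = y and is positive when y is slightly larger than x. A minimal illustration (assuming `SEKernel` from this package; values are approximate, not part of the PR):

```julia
k = SEKernel()
# Cov(∂₁Z(x), Z(y)) = ∂ₓ k(x, y) = -(x - y) * exp(-(x - y)^2 / 2)
_evaluate(k, DiffPt([1.0]; partial=1), DiffPt([1.0]))  # ≈ 0.0, since x == y
_evaluate(k, DiffPt([1.0]; partial=1), DiffPt([1.5]))  # ≈ 0.5 * exp(-0.125) > 0
```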
````julia
#=
This is a hack to work around the fact that the `where {T<:Kernel}` clause is
not allowed for the `(::T)(x,y)` syntax. If we were to only implement
```julia
(::Kernel)(::DiffPt,::DiffPt)
```
then julia would not know whether to use
`(::SpecialKernel)(x,y)` or `(::Kernel)(x::DiffPt, y::DiffPt)`.

To avoid this hack, no kernel type T should implement
```julia
(::T)(x,y)
```
and instead implement
```julia
_evaluate(k::T, x, y)
```
Then there should be only a single
```julia
(k::Kernel)(x,y) = _evaluate(k, x, y)
```
which all the kernels would fall back to.

This ensures that `_evaluate(k::T, x::DiffPt{Dim}, y::DiffPt{Dim})` is always
more specialized and called beforehand.
=#
for T in [SimpleKernel, Kernel] #subtypes(Kernel)
    (k::T)(x::DiffPt{Dim}, y::DiffPt{Dim}) where {Dim} = _evaluate(k, x, y)
    (k::T)(x::DiffPt{Dim}, y) where {Dim} = _evaluate(k, x, DiffPt(y))
    (k::T)(x, y::DiffPt{Dim}) where {Dim} = _evaluate(k, DiffPt(x), y)
end
````

**Review thread** (on lines +88 to +95, the comment block above):

> but this is a much more intrusive change, so to not blow up the lines changed, I did not do this yet.

> this is also why "detect ambiguities" fails right now

> I don't think we should switch to `_evaluate`.

> The ugly thing about wrapping existing kernels is that you essentially say: this is a different Gaussian process model. I.e. `f = GP(MaternKernel())` is a different (non-differentiable) model from `g = GP(DiffWrapper(MaternKernel()))`, which is differentiable. But fundamentally the Matern kernel implies that the Gaussian process should always be differentiable (to the order its smoothness parameter allows), and with this abstraction it is. I.e. you use
>
> ```julia
> x = 1:10
> fx = f(x)
> y = rand(fx)
> ```
>
> to simulate the process, and
>
> ```julia
> x = [DiffPt(0, partial=1), 1:10... ]
> fx = f(x)
> y0_grad, y... = rand(fx)
> ```
>
> to simulate its derivative at 0 together with its values at 1:10.

> Also note how this abstraction lets you mix and match normal points and `DiffPt`s.

> You can perform exactly the same computations with a wrapper type.
>
> I don't view it this way: the wrapper does not change the mathematical model, it just allows you to query derivatives as well.

> Well, if the wrapped kernel has a superset of the capabilities of the original kernel (without performance cost), why would you ever use the unwrapped kernel? So if you only ever use the wrapped kernel, then the unwrapped kernel is superfluous.

> No, it's not superfluous, as the wrapped kernel has a different API (i.e., you have to provide different types/structure of inputs) that is more inconvenient to work with if you're not interested in evaluating those. There's a difference - but IMO it's not a mathematical one but rather regarding user experience. As you've already noticed, I think it's also just not feasible to extend every implementation of differentiable kernels out there without making implementations of kernels more inconvenient for people that do not want to work with derivatives. So clearly separating these use cases seems simpler to me from a design perspective. The wrapper would be a kernel itself.

> No - you do not have to provide different types/structures of inputs, so it does not make it more inconvenient for people not interested in gradients. That is the entire point of specializing on `DiffPt`.
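For comparison with the review discussion above: a minimal sketch of what the wrapper alternative could look like. `DiffWrapper` is only a name used in the thread; the body below is an assumption that simply reuses this PR's `partial` and `DiffPt` machinery, not an existing implementation.

```julia
# Hypothetical sketch, not part of this PR: a dedicated wrapper kernel instead of
# adding DiffPt methods to every kernel type, so no dispatch hack is needed.
struct DiffWrapper{K<:Kernel} <: Kernel
    kernel::K
end

function (k::DiffWrapper)(x::DiffPt{Dim}, y::DiffPt{Dim}) where {Dim}
    return partial(k.kernel, Dim; partials_x=x.partial, partials_y=y.partial)(x.pos, y.pos)
end
(k::DiffWrapper)(x::DiffPt, y) = k(x, DiffPt(y))
(k::DiffWrapper)(x, y::DiffPt) = k(DiffPt(x), y)
(k::DiffWrapper)(x, y) = k.kernel(x, y)  # plain points fall through to the wrapped kernel

# usage, as in the thread above: g = GP(DiffWrapper(MaternKernel()))
```

This mirrors the trade-off debated above: the wrapper keeps `DiffPt` handling out of every kernel's dispatch table, at the cost of users having to opt in explicitly.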
New test file (+25 lines):

```julia
@testset "diffKernel" begin
    @testset "smoke test" begin
        k = MaternKernel()
        k(1, 1)
        k(1, DiffPt(1, partial=(1, 1))) # Cov(Z(x), ∂₁∂₁Z(y)) where x=1, y=1
        k(DiffPt([1], partial=1), [2]) # Cov(∂₁Z(x), Z(y)) where x=[1], y=[2]
        k(DiffPt([1, 2], partial=(1)), DiffPt([1, 2], partial=2)) # Cov(∂₁Z(x), ∂₂Z(y)) where x=[1,2], y=[1,2]
    end

    @testset "Sanity Checks with $k" for k in [SEKernel()]
        for x in [0, 1, -1, 42]
            # for stationary kernels Cov(∂Z(x), Z(x)) = 0
            @test k(DiffPt(x, partial=1), x) ≈ 0

            # the slope should be positively correlated with a point further down
            @test k(
                DiffPt(x, partial=1), # slope
                x + 1e-1 # point further down
            ) > 0

            # correlation with self should be positive
            @test k(DiffPt(x, partial=1), DiffPt(x, partial=1)) > 0
        end
    end
end
```
**Review thread** (on the ForwardDiff dependency):

> I think we don't want a dependency on ForwardDiff. We tried hard to avoid it so far.

> Yeah, maybe this should be a plugin somehow, but writing it as a plugin would have caused a bunch of boilerplate to review, and I wanted to make it easier to grasp the core idea.

> Maybe AbstractDifferentiation.jl would be the right abstraction?

> IMO it's not ready for proper use. But hopefully it will be at some point.