Introduce DiffPt for the covariance function of derivatives #508
base: master
Conversation
This is a hack to work around the fact that the `where {T<:Kernel}` clause is
not allowed for the `(::T)(x,y)` syntax. If we were to only implement
```julia
(::Kernel)(::DiffPt,::DiffPt)
```
then julia would not know whether to use
`(::SpecialKernel)(x,y)` or `(::Kernel)(x::DiffPt, y::DiffPt)`
To avoid this hack, no kernel type `T` should implement `(::T)(x, y)` and instead implement `_evaluate(k::T, x, y)`. Then there should be only a single `(k::Kernel)(x, y) = _evaluate(k, x, y)` which all the kernels would fall back to. This ensures that `_evaluate(k::T, x::DiffPt{Dim}, y::DiffPt{Dim}) where T<:Kernel` is always more specialized and called.
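A minimal sketch of this redirection pattern, using toy stand-ins (a `SqExpKernel`, a simplified `DiffPt`, first-order partials only, and ForwardDiff for the derivatives) rather than the actual types in this PR:

```julia
using ForwardDiff

abstract type Kernel end
struct SqExpKernel <: Kernel end   # toy stand-in for an existing kernel

struct DiffPt{T}
    pos::T
    partial::Tuple{Vararg{Int}}
end
DiffPt(x; partial=()) = DiffPt(x, Tuple(partial))

# kernels implement _evaluate instead of (::T)(x, y) ...
_evaluate(::SqExpKernel, x, y) = exp(-sum(abs2, x .- y) / 2)

# ... and a single call overload forwards to it:
(k::Kernel)(x, y) = _evaluate(k, x, y)

# the generic DiffPt method is then always more specific than the fallback above
function _evaluate(k::Kernel, x::DiffPt, y::DiffPt)
    f = (s, t) -> _evaluate(k, s, t)
    fx = isempty(x.partial) ? f :
        (s, t) -> ForwardDiff.gradient(u -> f(u, t), s)[x.partial[1]]
    fxy = isempty(y.partial) ? fx :
        (s, t) -> ForwardDiff.gradient(v -> fx(s, v), t)[y.partial[1]]
    return fxy(x.pos, y.pos)
end

k = SqExpKernel()
x, y = [0.0, 0.0], [1.0, 0.5]
k(x, y)                              # plain evaluation via the fallback
k(DiffPt(x; partial=1), DiffPt(y))   # Cov(∂₁Z(x), Z(y))
```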
But this is a much more intrusive change, so to not blow up the number of lines changed, I did not do this yet.
this is also why "detect ambiguities" fails right now
I don't think we should switch to `_evaluate(k, x, y)`. Instead of changing all other kernels, I think you should just create a single wrapper of existing kernels.
The ugly thing about wrapping existing kernels is that you essentially say: this is a different Gaussian process model. I.e.

```julia
f = GP(MaternKernel())
```

is a different (non-differentiable) model from

```julia
g = GP(DiffWrapper(MaternKernel()))
```

which is differentiable. But fundamentally the Matérn kernel already implies that the Gaussian process `f` ought to be differentiable. And with this abstraction it is. I.e. you use

```julia
x = 1:10
fx = f(x)
y = rand(fx)
```

to simulate `f` at the points `1:10`. If you now wanted to simulate its gradient at the point `0` too, you would just have to modify the input:

```julia
x = [DiffPt(0, partial=1), 1:10...]
fx = f(x)
y0_grad, y... = rand(fx)
```

and `y0_grad` would be the gradient at `0`, as expected.
Also note how this abstraction lets you mix and match normal points and `DiffPt`s.
> Also note how this abstraction lets you mix and match normal points and `DiffPt`s

You can perform exactly the same computations with a wrapper type.

> The ugly thing about wrapping existing kernels is that you essentially say: this is a different Gaussian process model.

I don't view it this way. The wrapper does not change the mathematical model, it just allows you to query derivatives as well.
> You can perform exactly the same computations with a wrapper type.

Well, if the wrapped kernel has a superset of the capabilities of the original kernel (without performance cost) - why would you ever use the unwrapped kernel? So if you only ever use the wrapped kernel, then

```julia
GP(DiffWrapper(MaternKernel()))
```

is kind of superfluous. So then you would probably start to write convenience functions to get the wrapped kernel immediately, but for the wrapped kernel the compositions like `+`, ... are not implemented. So you would have to pass all those through. It seems like a pointless effort to get a capability which the original kernel should already have. I mean the capability does not collide with anything.
No, it's not superfluous, as the wrapped kernel has a different API (i.e., you have to provide different types/structure of inputs) that is more inconvenient to work with if you're not interested in evaluating those. There's a difference - but IMO it's not a mathematical one but rather one regarding user experience.

As you've already noticed, I think it's also just not feasible to extend every implementation of differentiable kernels out there without making implementations of kernels more inconvenient for people that do not want to work with derivatives. So clearly separating these use cases seems simpler to me from a design perspective.

The wrapper would be a `Kernel` as well, of course, so if the existing implementations in KernelFunctions such as sums etc. are written as generally as they were intended, no new definitions for compositions are needed. You would add them only if there is a clear performance gain that outweighs code complexity.
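For concreteness, a rough sketch of what such a wrapper could look like, reusing the `DiffWrapper` name from the discussion and a `(position, index)` input convention similar to the multi-output format; this is illustrative only, not code from this PR:

```julia
using KernelFunctions, ForwardDiff

# Illustrative wrapper: behaves like the wrapped kernel on plain inputs and
# additionally understands (position, index) pairs, with index 0 meaning "no derivative".
struct DiffWrapper{K<:Kernel} <: Kernel
    kernel::K
end

(dk::DiffWrapper)(x, y) = dk.kernel(x, y)

function (dk::DiffWrapper)((x, i)::Tuple{Any,Int}, (y, j)::Tuple{Any,Int})
    f = (s, t) -> dk.kernel(s, t)
    fi = i == 0 ? f :
        (s, t) -> ForwardDiff.gradient(u -> f(u, t), s)[i]
    fij = j == 0 ? fi :
        (s, t) -> ForwardDiff.gradient(v -> fi(s, v), t)[j]
    return fij(x, y)
end

dk = DiffWrapper(SqExponentialKernel())
x, y = randn(2), randn(2)
dk(x, y)            # identical to SqExponentialKernel()(x, y)
dk((x, 1), (y, 0))  # Cov(∂₁Z(x), Z(y))
```

Since `DiffWrapper <: Kernel`, generic compositions such as sums would apply to it like to any other kernel, as argued above.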
No - you do not have to provide different types/structures of inputs, so it does not make it more inconvenient for people not interested in gradients. That is the entire point of specializing on `DiffPt`: the kernel function still works on everything that is not a `DiffPt`.
Codecov Report

Patch coverage has no change and project coverage change: -16.75%.

Additional details and impacted files:

```diff
@@             Coverage Diff             @@
##           master     #508       +/-   ##
===========================================
- Coverage   94.16%   77.41%   -16.75%
===========================================
  Files          52       54        +2
  Lines        1387     1430       +43
===========================================
- Hits         1306     1107      -199
- Misses         81      323      +242
```

☔ View full report in Codecov by Sentry.
```diff
@@ -8,10 +8,12 @@ Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
 CompositionsBase = "a33af91c-f02d-484b-be07-31d278c5ca2b"
 Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
 FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
+ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
```
I think we don't want a dependency on ForwardDiff. We tried hard to avoid it so far.
Yeah, maybe this should be a plugin somehow, but writing it as a plugin would have caused a bunch of boilerplate to review, and I wanted to make it easier to grasp the core idea.
Maybe AbstractDifferentiation.jl would be the right abstraction?
IMO it's not ready for proper use. But hopefully it will be at some point.
```diff
@@ -125,6 +127,7 @@ include("chainrules.jl")
 include("zygoterules.jl")

 include("TestUtils.jl")
+include("diffKernel.jl")
```
Kernels are contained in the kernel subfolder.
Well, it is not a new kernel - it extends the functionality of all kernels, which is why I did not put it there. But maybe it makes sense to put it there anyway.
I think it should be a new kernel. That's how you can nicely fit it into the KernelFunctions ecosystem.
```julia
import LinearAlgebra as LA

"""
    DiffPt(x; partial=())
```
Why is a separate type needed? Wouldn't it be better to use the existing input formats for multi-output kernels?
Where can I find the existing input formats? Here https://juliagaussianprocesses.github.io/KernelFunctions.jl/stable/design/#inputs_for_multiple_outputs it says that you would use `Tuple{T, Int}`, and while you could in principle convert a `partial` tuple to an `Int`, you really don't want to. I mean, even if we are gracious and start counting with zero to make this less of a mess:

```julia
()     -> 0        # no derivative
(1,)   -> 1        # partial derivative in direction 1
(2,)   -> 2
...
(dim,) -> dim      # partial derivative in direction dim
(1,1)  -> dim + 1  # twice partial derivative in direction 1
(1,2)  -> dim + 2
...
(k,j)  -> k * dim + j
```

the user would have to do this conversion from Cartesian coordinates to linear coordinates. And then the implementation would revert this transformation back from linear coordinates to Cartesian coordinates. Cartesian coordinates with ex ante unknown tuple length, that is.

That all seems annoying without being really necessary. But sure - you could do it.
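Just to illustrate the bookkeeping involved, a sketch of the two conversions (hypothetical helper names, derivatives up to second order only):

```julia
# Tuple of partials -> linear index, counting from zero:
# () -> 0, (i,) -> i, (k, j) -> k*dim + j
function partial_to_linear(partial::Tuple, dim::Int)
    isempty(partial) && return 0
    length(partial) == 1 && return partial[1]
    length(partial) == 2 && return partial[1] * dim + partial[2]
    return error("only up to second order in this sketch")
end

# ... and the inverse that the implementation would have to apply again.
function linear_to_partial(idx::Int, dim::Int)
    idx == 0 && return ()
    idx <= dim && return (idx,)
    j = mod1(idx, dim)
    return ((idx - j) ÷ dim, j)
end

partial_to_linear((1, 2), 3)  # 5
linear_to_partial(5, 3)       # (1, 2)
```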
Don't you just need $X \times \{0, \ldots, dim\}$ to represent all desired partials?
What is $X$? (Btw, MathJax is enabled: `$` works for inline and `` ```math `` works for multiline maths.)
And this is a bit tricky. In essence, a variable-length tuple carries more information than just the longest tuple. Plus there is no longest tuple.
```julia
implements `(k::T)(x::DiffPt{Dim}, y::DiffPt{Dim})` for all kernel types. But since
generics are not allowed in the syntax above by the dispatch system, this
redirection over `_evaluate` is necessary
```
See my other comment; I think this should not be done and a simple wrapper would be sufficient.
src/diffKernel.jl
```julia
i.e. 2*dim dimensional input
"""
function partial(k, dim; partials_x=(), partials_y=())
    local f(x, y) = partial(t -> k(t, y), dim, partials_x)(x)
```
`local` is not needed?
I guess - it felt right to explicitly say that this function is only temporary and is immediately going to be transformed again.
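For reference, a minimal sketch of what a single-argument `partial(f, dim, partials)` helper along these lines could look like with ForwardDiff (illustrative only, not the exact code in the diff):

```julia
using ForwardDiff

# Return the function x -> ∂_{i₁}∂_{i₂}… f(x) for a scalar-valued f on ℝ^dim
# (dim is unused here and only kept to mirror the signature in the diff).
function partial(f, dim, partials::Tuple=())
    isempty(partials) && return f
    i = first(partials)
    df = x -> ForwardDiff.gradient(f, x)[i]
    return partial(df, dim, Base.tail(partials))
end

g(x) = x[1]^2 * x[2]
partial(g, 2, (1, 2))([3.0, 5.0])  # ∂₁∂₂ g = 2 * x[1] = 6.0
```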
I have been playing around with the ideas in this PR, and realized that to make this work there are some open questions. At this point a wrapper might be easier, because then the only thing needed are some additional methods.
**kernelmatrix**

@Crown421 the question is whether the function

```julia
function kernelmatrix(κ::SimpleKernel, x::AbstractVector)
    return map(x -> kappa(κ, x), pairwise(metric(κ), x))
end
```

is more performant than

```julia
function kernelmatrix(κ::SimpleKernel, x::AbstractVector)
    return broadcast(κ, x, permutedims(x))
end
```

I mean

```julia
K = broadcast(x, permutedims(x)) do x, y
    metric(κ)(x, y)
end  # first pass over K
map(x -> kappa(κ, x), K)  # second pass over K which accesses the elements of K again
```

needs two passes over `K`, while

```julia
K = broadcast(x, permutedims(x)) do x, y
    κ(x, y)
end
```

only requires one access. Since memory access is typically the bottleneck, the general definition should be more performant. That is unless

```julia
(κ::SimpleKernel)(x, y) = kappa(κ, metric(κ)(x, y))
```

is not inlined and causes more function calls. But in that case it probably makes more sense to force inlining of

```julia
K = broadcast(x, permutedims(x)) do x, y
    kappa(κ, metric(κ)(x, y))
end
```

by the compiler, which should be faster than the two-pass version.

**Issues with a wrapper**

Composition does not commute with the wrapper, i.e.

```julia
DiffWrapper(kernel) ∘ FunctionTransform(f) != DiffWrapper(kernel ∘ FunctionTransform(f))
```

So if you wanted to treat the transformed kernel as differentiable as well, you would have to be careful about where the wrapper is applied.

**Caching the cholesky decomposition**

I do not understand this point. This would automatically happen with this implementation too.
This is really weird...

```julia
julia> @btime kernelmatrix(k, Xc);
  23.836 ms (5 allocations: 61.05 MiB)

julia> @btime map(x -> KernelFunctions.kappa(k, x), KernelFunctions.pairwise(KernelFunctions.metric(k), Xc));
  103.780 ms (8000009 allocations: 183.12 MiB)

julia> @btime k.(Xc, permutedims(Xc));
  78.818 ms (4 allocations: 30.52 MiB)

julia> size(Xc)
(2000,)

julia> size(first(Xc))
(2,)

julia> k
Squared Exponential Kernel (metric = Distances.Euclidean(0.0))
```

Why is the performance of the implementation

```julia
function kernelmatrix(κ::SimpleKernel, x::AbstractVector)
    return map(x -> kappa(κ, x), pairwise(metric(κ), x))
end
```

(https://github.com/JuliaGaussianProcesses/KernelFunctions.jl/blob/master/src/matrix/kernelmatrix.jl#L149-L151) so different from calling its body directly, as above?
Not sure what function you mean here and what you expect to be worse/better. The main issue is that your benchmarking is flawed: variables etc. have to be interpolated, since otherwise you suffer, sometimes massively, from type instabilities and inference issues introduced by global variables. So instead you should perform benchmarks such as

```julia
julia> using KernelFunctions, BenchmarkTools

julia> Xc = ColVecs(randn(2, 2000));

julia> k = GaussianKernel();

julia> @btime kernelmatrix($k, $Xc);
  38.594 ms (5 allocations: 61.05 MiB)

julia> @btime kernelmatrix($k, $Xc);
  35.585 ms (5 allocations: 61.05 MiB)

julia> @btime map(x -> KernelFunctions.kappa($k, x), KernelFunctions.pairwise(KernelFunctions.metric($k), $Xc));
  37.478 ms (5 allocations: 61.05 MiB)

julia> @btime map(x -> KernelFunctions.kappa($k, x), KernelFunctions.pairwise(KernelFunctions.metric($k), $Xc));
  33.321 ms (5 allocations: 61.05 MiB)

julia> @btime $k.($Xc, permutedims($Xc));
  45.019 ms (2 allocations: 30.52 MiB)

julia> @btime $k.($Xc, permutedims($Xc));
  45.339 ms (2 allocations: 30.52 MiB)
```

> Since memory access is typically the bottleneck

No, not generally.

> the general definition should be more performant

Therefore this statement also does not hold generally.

If you are concerned about memory allocations, you probably also might want to use `kernelmatrix!`:

```julia
julia> # Continued from above

julia> K = Matrix{Float64}(undef, length(Xc), length(Xc));

julia> @btime kernelmatrix!($K, $k, $Xc);
  25.775 ms (1 allocation: 15.75 KiB)

julia> @btime kernelmatrix!($K, $k, $Xc);
  30.012 ms (1 allocation: 15.75 KiB)
```

Another disadvantage of broadcasting is that generally it means more work for the compiler (the whole broadcasting machinery is very involved and complicated) and hence increases compilation times.
@devmotion ahh 🤦 I guess if `pairwise` actually uses the symmetry of distances, then I see where the speedup in the isotropic case comes from.
Yes, I try to avoid it.
My apologies, I had an error in thinking here: I was convinced that an additional matrix would need to be cached, not sure why.

Well, not zero reason. There are multiple reasons for using a wrapper in this PR, and therefore it comes down to opinion. It would be easy to define some fallback functions that throw an error in problematic cases, advising users to use the wrapper at the end. Given that differentiable kernels would not be a core feature, but rather an extension when also loading a compatible autodiff package, any changes in the main part of KernelFunctions should be minimal and not reduce any performance. Therefore I would personally prefer starting with a wrapper, at least for now, to have the key functionality available and see additional issues during use.

For example, I have already wondered about several usability questions that have a much higher "ugliness" potential than where one can put a wrapper: during a normal session, I manipulate a lot of inputs and input collections, but only define a GP/kernel once.
I am starting to agree, given that I cannot come up with a good solution to the kernelmatrix problem at the moment.

That is something I am currently thinking about a lot. I would think that custom composite types would be a good idea. Storing

```julia
(x, 2)
(x, n),
(y, 1),
...
(y, n)
```

could be replaced and emulated by some sort of dictionary

```julia
x => [2, n]
y => 1:n
```

The advantage is that you could specialize on index ranges to take more than one partial derivative (and use backward diff to get the entire gradient). But you would still need the ability to interleave points

```julia
(x, 1)
(y, 2)
(x, 2)
```

and I am not yet sure how to fix the abstract order of the points. Basically what should probably happen is something akin to an SQL join:

```
TABLE: Entries
TABLE: Positions
```

A left join on (Entries, Positions) would then result in the theoretical list

```julia
[
    DiffPt(pos1, ()),
    DiffPt(pos1, (1, 2)),
    DiffPt(pos2, (2,)),
    ...
]
```

But now I don't have the ranges yet...
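A tiny sketch of the expansion step such a composite type would need, using a vector of pairs (rather than a `Dict`) so that the order stays well defined; the helper name and types are hypothetical:

```julia
# Expand `position => collection of partial indices` pairs into the flat list
# of (position, partial) tuples, in the order the pairs are given.
function expand_diff_inputs(spec::AbstractVector{<:Pair})
    out = Tuple{Any,Int}[]
    for (pos, partials) in spec
        for p in partials
            push!(out, (pos, p))
        end
    end
    return out
end

x, y = [0.0, 1.0], [2.0, 3.0]
expand_diff_inputs([x => [2, 4], y => 1:4])
# -> [(x, 2), (x, 4), (y, 1), (y, 2), (y, 3), (y, 4)] (with x and y spelled out)
```

Interleaving, as in `(x, 1), (y, 2), (x, 2)`, indeed cannot be expressed this way, which is the ordering problem mentioned above.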
Summary
This is a minimal implementation to enable the simulation of gradients (and higher-order derivatives) of GPs (see also #504).
Proposed changes
For a covariance kernel $k$ of a GP $Z$, i.e. $k(x, y) = \mathrm{Cov}(Z(x), Z(y))$, a `DiffPt` allows the differentiation of $Z$, i.e. $k(\mathrm{DiffPt}(x, \mathrm{partial}=i), y) = \mathrm{Cov}\bigl(\tfrac{\partial}{\partial x_i} Z(x), Z(y)\bigr)$. For higher-order derivatives, `partial` can be any iterable, i.e. $k(\mathrm{DiffPt}(x, \mathrm{partial}=(i,j)), y) = \mathrm{Cov}\bigl(\tfrac{\partial^2}{\partial x_i \partial x_j} Z(x), Z(y)\bigr)$.
The code for this feature is extremely minimal but allows the simulation of arbitrary derivatives of Gaussian processes. It only contains

- `DiffPt`
- an `_evaluate(::T, x::DiffPt, y::DiffPt) where {T<:Kernel}` function, which calls
- `partial` functions that take the derivatives.

What alternatives have you considered?
This is the implementation with the smallest footprint but not the most performant. What essentially happens here is the simulation of the multivariate GP $f = (Z, \nabla Z)$, which is a $d+1$ dimensional GP if $Z$ is a univariate GP with input dimension $d$. Due to the "no multi-variate kernels" design philosophy of KernelFunctions.jl we are forced to calculate the entries of the covariance matrix one-by-one. It would be more performant to calculate the entire matrix in one go, using backward diff for the first pass and forward diff for the second derivative.

It might be possible to somehow specialize on ranges to get back this performance, but it is not completely clear how: we do not make a call that could easily be caught by specializing on `broadcast`; in reality the kernel is evaluated on all pairs of the input lists, not just a zip.
Breaking changes
None.