From ddc8994a15f1e92118ac95a6c0a30b6ed918a6d4 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Fri, 2 Jun 2023 14:54:30 +0200 Subject: [PATCH 1/4] Add LazyKernelMatrix and lazykernelmatrix --- src/KernelFunctions.jl | 2 + src/matrix/lazykernelmatrix.jl | 109 +++++++++++++++++++++++++++++++++ 2 files changed, 111 insertions(+) create mode 100644 src/matrix/lazykernelmatrix.jl diff --git a/src/KernelFunctions.jl b/src/KernelFunctions.jl index 63205b5bf..56725313a 100644 --- a/src/KernelFunctions.jl +++ b/src/KernelFunctions.jl @@ -1,6 +1,7 @@ module KernelFunctions export kernelmatrix, kernelmatrix!, kernelmatrix_diag, kernelmatrix_diag! +export LazyKernelMatrix, lazykernelmatrix export duplicate, set! # Helpers export Kernel, MOKernel @@ -106,6 +107,7 @@ include("kernels/gibbskernel.jl") include("kernels/scaledkernel.jl") include("kernels/normalizedkernel.jl") include("matrix/kernelmatrix.jl") +include("matrix/lazykernelmatrix.jl") include("kernels/kernelsum.jl") include("kernels/kernelproduct.jl") include("kernels/kerneltensorproduct.jl") diff --git a/src/matrix/lazykernelmatrix.jl b/src/matrix/lazykernelmatrix.jl new file mode 100644 index 000000000..f20c96aaf --- /dev/null +++ b/src/matrix/lazykernelmatrix.jl @@ -0,0 +1,109 @@ +"""
    lazykernelmatrix(κ::Kernel, x::AbstractVector) -> AbstractMatrix

Construct a lazy representation of the kernel `κ` for each pair of inputs in `x`.

The result is a matrix with the same entries as [`kernelmatrix(κ, x)`](@ref) but where the
entries are not computed until they are needed.
"""
lazykernelmatrix(κ::Kernel, x) = lazykernelmatrix(κ, x, x)

"""
    lazykernelmatrix(κ::Kernel, x::AbstractVector, y::AbstractVector) -> AbstractMatrix

Construct a lazy representation of the kernel `κ` for each pair of inputs in `x` and `y`.

The result is a matrix with the same entries as [`kernelmatrix(κ, x, y)`](@ref) but where
the entries are not computed until they are needed. 
+""" +lazykernelmatrix(κ::Kernel, x, y) = LazyKernelMatrix(κ, x, y) + +""" + LazyKernelMatrix(κ::Kernel, x[, y]) + LazyKernelMatrix{T<:Real}(κ::Kernel, x, y) + +Construct a lazy representation of the kernel `κ` for each pair of inputs in `x` and `y`. + +Instead of constructing this directly, it is better to call +[`lazykernelmatrix(κ, x[, y])`](@ref lazykernelmatrix). +""" +struct LazyKernelMatrix{T<:Real,Tk<:Kernel,Tx<:AbstractVector,Ty<:AbstractVector} <: + AbstractMatrix{T} + kernel::Tk + x::Tx + y::Ty + function LazyKernelMatrix{T}(κ::Tk, x::Tx, y::Ty) where {T<:Real,Tk<:Kernel,Tx,Ty} + Base.require_one_based_indexing(x) + Base.require_one_based_indexing(y) + return new{T,Tk,Tx,Ty}(κ, x, y) + end + function LazyKernelMatrix{T}(κ::Tk, x::Tx) where {T<:Real,Tk<:Kernel,Tx} + Base.require_one_based_indexing(x) + return new{T,Tk,Tx,Tx}(κ, x, x) + end +end +function LazyKernelMatrix(κ::Kernel, x::AbstractVector, y::AbstractVector) + # evaluate once to get eltype + T = typeof(κ(first(x), first(y))) + return LazyKernelMatrix{T}(κ, x, y) +end +LazyKernelMatrix(κ::Kernel, x::AbstractVector) = LazyKernelMatrix(κ, x, x) + +Base.Matrix(K::LazyKernelMatrix) = kernelmatrix(K.kernel, K.x, K.y) +function Base.AbstractMatrix{T}(K::LazyKernelMatrix) where {T} + return LazyKernelMatrix{T}(K.kernel, K.x, K.y) +end + +Base.size(K::LazyKernelMatrix) = (length(K.x), length(K.y)) + +Base.axes(K::LazyKernelMatrix) = (axes(K.x, 1), axes(K.y, 1)) + +function Base.getindex(K::LazyKernelMatrix{T}, i::Int, j::Int) where {T} + return T(K.kernel(K.x[i], K.y[j])) +end +for f in (:getindex, :view) + @eval begin + function Base.$f( + K::LazyKernelMatrix{T}, + I::Union{Colon,AbstractVector}, + J::Union{Colon,AbstractVector}, + ) where {T} + return LazyKernelMatrix{T}(K.kernel, $f(K.x, I), $f(K.y, J)) + end + end +end + +Base.zero(K::LazyKernelMatrix{T}) where {T} = LazyKernelMatrix{T}(ZeroKernel(), K.x, K.y) +Base.one(K::LazyKernelMatrix{T}) where {T} = LazyKernelMatrix{T}(WhiteKernel(), K.x, K.y) 
+

function Base.:*(c::S, K::LazyKernelMatrix{T}) where {T,S<:Real}
    R = typeof(oneunit(S) * oneunit(T))
    return LazyKernelMatrix{R}(c * K.kernel, K.x, K.y)
end
Base.:*(K::LazyKernelMatrix, c::Real) = c * K
Base.:/(K::LazyKernelMatrix, c::Real) = K * inv(c)
Base.:\(c::Real, K::LazyKernelMatrix) = inv(c) * K

function Base.:+(K::LazyKernelMatrix{T}, C::UniformScaling{S}) where {T,S<:Real}
    if isequal(K.x, K.y)
        R = typeof(zero(T) + zero(S))
        return LazyKernelMatrix{R}(K.kernel + C.λ * WhiteKernel(), K.x, K.y)
    else
        return Matrix(K) + C
    end
end
function Base.:+(C::UniformScaling{S}, K::LazyKernelMatrix{T}) where {T,S<:Real}
    if isequal(K.x, K.y)
        R = typeof(zero(T) + zero(S))
        return LazyKernelMatrix{R}(C.λ * WhiteKernel() + K.kernel, K.x, K.y)
    else
        return C + Matrix(K)
    end
end
function Base.:+(K1::LazyKernelMatrix, K2::LazyKernelMatrix)
    if isequal(K1.x, K2.x) && isequal(K1.y, K2.y)
        return LazyKernelMatrix(K1.kernel + K2.kernel, K1.x, K1.y)
    else
        return Matrix(K1) + Matrix(K2)
    end
end From 70ddd0355a66c6f8bda3b2944b682f73a8460f98 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Fri, 2 Jun 2023 14:54:39 +0200 Subject: [PATCH 2/4] Update API docs --- docs/src/api.md | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/docs/src/api.md b/docs/src/api.md index 9fb241fa4..1def1fcdb 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -6,7 +6,9 @@ CurrentModule = KernelFunctions ## Functions -The KernelFunctions API comprises the following four functions. +The KernelFunctions API comprises the following functions. + +The first set eagerly constructs all or part of a kernel matrix ```@docs kernelmatrix kernelmatrix! kernelmatrix_diag kernelmatrix_diag! ``` +It is also possible to lazily construct the same matrix +```@docs +lazykernelmatrix +LazyKernelMatrix +``` + ## Input Types The above API operates on collections of inputs. 
From 32f17d691d4ecc2c619afa1610213e3af1a3aff6 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Fri, 2 Jun 2023 14:55:27 +0200 Subject: [PATCH 3/4] Update kernelmatrix docstring --- src/matrix/kernelmatrix.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/matrix/kernelmatrix.jl b/src/matrix/kernelmatrix.jl index e778f79eb..77ee9c007 100644 --- a/src/matrix/kernelmatrix.jl +++ b/src/matrix/kernelmatrix.jl @@ -30,12 +30,16 @@ Compute the kernel `κ` for each pair of inputs in `x`. Returns a matrix of size `(length(x), length(x))` satisfying `kernelmatrix(κ, x)[p, q] == κ(x[p], x[q])`. +If `x` is large, consider using [`lazykernelmatrix`](@ref) instead. + kernelmatrix(κ::Kernel, x::AbstractVector, y::AbstractVector) Compute the kernel `κ` for each pair of inputs in `x` and `y`. Returns a matrix of size `(length(x), length(y))` satisfying `kernelmatrix(κ, x, y)[p, q] == κ(x[p], y[q])`. +If `x` and `y` are large, consider using [`lazykernelmatrix`](@ref) instead. + kernelmatrix(κ::Kernel, X::AbstractMatrix; obsdim) kernelmatrix(κ::Kernel, X::AbstractMatrix, Y::AbstractMatrix; obsdim) From 4d7d0b2fa6e51e7613fcecc319efeacbbb47d929 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Fri, 2 Jun 2023 15:01:22 +0200 Subject: [PATCH 4/4] Update more docs --- docs/src/api.md | 2 +- docs/src/create_kernel.md | 2 +- docs/src/userguide.md | 9 ++++++++- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index 1def1fcdb..13042cab0 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -16,7 +16,7 @@ kernelmatrix_diag kernelmatrix_diag! 
``` -It is also possible to lazily construct the same matrix +It is also possible to lazily construct the same matrix, which is recommended when the kernel matrix might be too large to store in memory ```@docs lazykernelmatrix LazyKernelMatrix diff --git a/docs/src/create_kernel.md b/docs/src/create_kernel.md index c7234b41f..870792374 100644 --- a/docs/src/create_kernel.md +++ b/docs/src/create_kernel.md @@ -38,7 +38,7 @@ Finally there are additional functions you can define to bring in more features: - `KernelFunctions.iskroncompatible(k::MyKernel)`: if your kernel factorizes in dimensions, you can declare your kernel as `iskroncompatible(k) = true` to use Kronecker methods. - `KernelFunctions.dim(x::MyDataType)`: by default the dimension of the inputs will only be checked for vectors of type `AbstractVector{<:Real}`. If you want to check the dimensionality of your inputs, dispatch the `dim` function on your datatype. Note that `0` is the default. - `dim` is called within `KernelFunctions.validate_inputs(x::MyDataType, y::MyDataType)`, which can instead be directly overloaded if you want to run special checks for your input types. - - `kernelmatrix(k::MyKernel, ...)`: you can redefine the diverse `kernelmatrix` functions to eventually optimize the computations. + - `kernelmatrix(k::MyKernel, ...)`: you can redefine the diverse `kernelmatrix` and `lazykernelmatrix` functions to eventually optimize the computations. - `Base.print(io::IO, k::MyKernel)`: if you want to specialize the printing of your kernel. KernelFunctions uses [Functors.jl](https://github.com/FluxML/Functors.jl) for specifying trainable kernel parameters diff --git a/docs/src/userguide.md b/docs/src/userguide.md index d3e8789f7..269c979cc 100644 --- a/docs/src/userguide.md +++ b/docs/src/userguide.md @@ -61,7 +61,7 @@ k(x1, x2) ## Creating a Kernel Matrix -Kernel matrices can be created via the `kernelmatrix` function or `kernelmatrix_diag` for only the diagonal. 
+Kernel matrices can be eagerly created via the `kernelmatrix` function or `kernelmatrix_diag` for only the diagonal. For example, for a collection of 10 `Real`-valued inputs: ```julia k = SqExponentialKernel() @@ -90,6 +90,13 @@ kernelmatrix(k, X; obsdim=2) # same as ColVecs(X) ``` This is similar to the convention used in [Distances.jl](https://github.com/JuliaStats/Distances.jl). +When data is large, it may not be possible to store the kernel matrix in memory. +Then it is recommended to use `lazykernelmatrix`: +```julia +lazykernelmatrix(k, RowVecs(X)) +lazykernelmatrix(k, ColVecs(X)) +``` + ### So what type should I use to represent a collection of inputs? The central assumption made by KernelFunctions.jl is that all collections of `N` inputs are represented by `AbstractVector`s of length `N`. Abstraction is then used to ensure that efficiency is retained, `ColVecs` and `RowVecs`