From cf937ceed9b2666a812cf1ec7350ac6c5a231d00 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 27 Sep 2023 17:07:33 +0100 Subject: [PATCH] Make pullback error for ColVecs and RowVecs a bit more informative (#523) * make adjoint error message a bit more informative * Apply suggestions from code review * Update Project.toml * Update src/chainrules.jl --- Project.toml | 2 +- src/chainrules.jl | 12 ++++++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index 617f6e56c..2830de211 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "KernelFunctions" uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392" -version = "0.10.56" +version = "0.10.57" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/chainrules.jl b/src/chainrules.jl index 3b52860dd..eebdf95b5 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -150,8 +150,11 @@ function ChainRulesCore.rrule(::Type{<:ColVecs}, X::AbstractMatrix) function ColVecs_pullback(::AbstractVector{<:AbstractVector{<:Real}}) return error( "Pullback on AbstractVector{<:AbstractVector}.\n" * - "This might happen if you try to use gradients on the generic `kernelmatrix` or `kernelmatrix_diag`.\n" * - "To solve this issue overload `kernelmatrix(_diag)` for your kernel for `ColVecs`", + "This might happen if you try to use gradients on the generic `kernelmatrix` or `kernelmatrix_diag`,\n" * + "or because some external computation has acted on `ColVecs` to produce a vector of vectors." * + "In the former case, to solve this issue overload `kernelmatrix(_diag)` for your kernel for `ColVecs`." * + "In the latter case, one needs to track down the `rrule` whose pullback returns a `Vector{Vector{T}}`," * + " rather than a `Tangent`, as the cotangent / gradient for `ColVecs` input, and circumvent it." ) end return ColVecs(X), ColVecs_pullback @@ -162,8 +165,9 @@ function ChainRulesCore.rrule(::Type{<:RowVecs}, X::AbstractMatrix) function RowVecs_pullback(::AbstractVector{<:AbstractVector{<:Real}}) return error( "Pullback on AbstractVector{<:AbstractVector}.\n" * - "This might happen if you try to use gradients on the generic `kernelmatrix` or `kernelmatrix_diag`.\n" * - "To solve this issue overload `kernelmatrix(_diag)` for your kernel for `RowVecs`", + "This might happen if you try to use gradients on the generic `kernelmatrix` or `kernelmatrix_diag`,\n" * + "or because some external computation has acted on `RowVecs` to produce a vector of vectors." * + "If it is the former, to solve this issue overload `kernelmatrix(_diag)` for your kernel for `RowVecs`", ) end return RowVecs(X), RowVecs_pullback