From 942eeade0e1d5a9b092f8a9928f8c4807a699487 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Sun, 29 Sep 2024 21:26:08 +0200 Subject: [PATCH] Handle varargs correctly in value_and_pullback!! and value_and_gradient!! (#277) * Fix __verify_sig for vararg calls * Bump patch version * Bump TemporalGPs test dep version --- Project.toml | 4 ++-- src/interface.jl | 7 ++++--- test/interface.jl | 1 + 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index 40b29d606..23eb79576 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Mooncake" uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" authors = ["Will Tebbutt, Hong Ge, and contributors"] -version = "0.4.4" +version = "0.4.5" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -51,7 +51,7 @@ PDMats = "0.11" Setfield = "1" SpecialFunctions = "2" StableRNGs = "1" -TemporalGPs = "0.6" +TemporalGPs = "0.7" julia = "1.10" [extras] diff --git a/src/interface.jl b/src/interface.jl index 89ee34466..56c02e362 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -19,10 +19,11 @@ function __value_and_pullback!!(rule::R, ȳ::T, fx::Vararg{CoDual, N}) where {R end function __verify_sig( - ::DerivedRule{<:Any, <:MistyClosure{<:OpaqueClosure{sig}}}, ::Tfx + rule::DerivedRule{<:Any, <:MistyClosure{<:OpaqueClosure{sig}}}, fx::Tfx ) where {sig, Tfx} - if sig != Tfx - msg = "signature of arguments, $Tfx, not equal to signature required by rule, $sig." + Pfx = typeof(__unflatten_codual_varargs(rule.isva, fx, rule.nargs)) + if sig != Pfx + msg = "signature of arguments, $Pfx, not equal to signature required by rule, $sig." throw(ArgumentError(msg)) end end diff --git a/test/interface.jl b/test/interface.jl index c6a583cf2..8f3ef4ecf 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -27,6 +27,7 @@ (x -> sin(cos(x)), randn(Float32)), ((x, y) -> x + sin(y), randn(Float64), randn(Float64)), ((x, y) -> x + sin(y), randn(Float32), randn(Float32)), + ((x...) -> x[1] + x[2], randn(Float64), randn(Float64)), ] rule = build_rrule(fargs...) f, args... = fargs