From 39201d990d4d65eabcccbcfd4f445059e5e1f025 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sat, 8 May 2021 00:29:46 +0200 Subject: [PATCH 1/4] unthunk in adjoint --- src/adjoint.jl | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/adjoint.jl b/src/adjoint.jl index a5f3a5d..dd232f1 100644 --- a/src/adjoint.jl +++ b/src/adjoint.jl @@ -30,6 +30,12 @@ function adjoint end function _pullback end function pullback end + +function unthunk_tangent end +@inline unthunk_tangent(x) = x +@inline unthunk_tangent(x::Tuple) = map(unthunk_tangent, x) + + function gradm(ex, mut = false) @capture(shortdef(ex), (name_(args__) = body_) | (name_(args__) where {Ts__} = body_)) || error("Need a function definition") @@ -56,13 +62,13 @@ function gradm(ex, mut = false) @inline function ZygoteRules._pullback($cx, $f::$T, $(args...)) where $(Ts...) y, _back = adjoint(__context__, $f, $(argnames...)) $(mut ? nothing : :(back(::Nothing) = nothing)) - back(Δ) = $gradtuple(_back(Δ)) + back(Δ) = $gradtuple(_back(unthunk_tangent(Δ))) return y, back end @inline function ZygoteRules._pullback($cx, ::$kT, kw, $f::$T, $(args...)) where $(Ts...) y, _back = adjoint(__context__, $f, $(argnames...); kw...) $(mut ? nothing : :(back(::Nothing) = nothing)) - back(Δ) = $gradtuplekw(_back(Δ)) + back(Δ) = $gradtuplekw(_back(unthunk_tangent(Δ))) return y, back end nothing From 7c2b91014cc22072024f2244fe4e8e51eb22e710 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sat, 8 May 2021 02:34:51 +0200 Subject: [PATCH 2/4] Add macro adjoint_keepthunks --- src/ZygoteRules.jl | 2 +- src/adjoint.jl | 19 ++++++++++++++----- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/src/ZygoteRules.jl b/src/ZygoteRules.jl index 9d20f88..0aaaccd 100644 --- a/src/ZygoteRules.jl +++ b/src/ZygoteRules.jl @@ -1,6 +1,6 @@ module ZygoteRules -export @adjoint, @adjoint! +export @adjoint, @adjoint!, @adjoint_keepthunks, @adjoint_keepthunks! 
""" ZygoteRules.literal_getproperty(x, ::Val{f}) diff --git a/src/adjoint.jl b/src/adjoint.jl index dd232f1..a40c6d1 100644 --- a/src/adjoint.jl +++ b/src/adjoint.jl @@ -36,7 +36,7 @@ function unthunk_tangent end @inline unthunk_tangent(x::Tuple) = map(unthunk_tangent, x) -function gradm(ex, mut = false) +function gradm(ex, mut = false, keepthunks = false) @capture(shortdef(ex), (name_(args__) = body_) | (name_(args__) where {Ts__} = body_)) || error("Need a function definition") kw = length(args) > 1 && isexpr(args[1], :parameters) ? esc(popfirst!(args)) : nothing @@ -57,18 +57,19 @@ function gradm(ex, mut = false) gradtuple = isclosure ? gradtuple0 : gradtuple1 gradtuplekw = isclosure ? gradtuple2 : gradtuple3 adj = @q @inline ZygoteRules.adjoint($(fargs...)) where $(Ts...) = $(esc(body)) + maybe_unthunked_Δ = keepthunks ? :Δ : :(unthunk_tangent(Δ)) quote $adj @inline function ZygoteRules._pullback($cx, $f::$T, $(args...)) where $(Ts...) y, _back = adjoint(__context__, $f, $(argnames...)) $(mut ? nothing : :(back(::Nothing) = nothing)) - back(Δ) = $gradtuple(_back(unthunk_tangent(Δ))) + back(Δ) = $gradtuple(_back($maybe_unthunked_Δ)) return y, back end @inline function ZygoteRules._pullback($cx, ::$kT, kw, $f::$T, $(args...)) where $(Ts...) y, _back = adjoint(__context__, $f, $(argnames...); kw...) $(mut ? 
nothing : :(back(::Nothing) = nothing)) - back(Δ) = $gradtuplekw(_back(unthunk_tangent(Δ))) + back(Δ) = $gradtuplekw(_back($maybe_unthunked_Δ)) return y, back end nothing @@ -76,9 +77,17 @@ function gradm(ex, mut = false) end macro adjoint(ex) - gradm(ex) + gradm(ex, false, false) end macro adjoint!(ex) - gradm(ex, true) + gradm(ex, true, false) +end + +macro adjoint_keepthunks(ex) + gradm(ex, false, true) +end + +macro adjoint_keepthunks!(ex) + gradm(ex, true, true) end From b8b91a26376c4bb7b40397cf793f8aae363ed6ac Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sat, 8 May 2021 02:39:56 +0200 Subject: [PATCH 3/4] Map unthunk_tangent over NamedTuples --- src/adjoint.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/adjoint.jl b/src/adjoint.jl index a40c6d1..dc10961 100644 --- a/src/adjoint.jl +++ b/src/adjoint.jl @@ -34,6 +34,7 @@ function pullback end function unthunk_tangent end @inline unthunk_tangent(x) = x @inline unthunk_tangent(x::Tuple) = map(unthunk_tangent, x) +@inline unthunk_tangent(x::NamedTuple) = map(unthunk_tangent, x) function gradm(ex, mut = false, keepthunks = false) From 05cd6e1d41a363b2114fcce2a640df145006a5b7 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Tue, 11 May 2021 22:47:16 +0200 Subject: [PATCH 4/4] Remove macro adjoint_keepthunks Replaced by internal macro _adjoint_keepthunks in Zygote. --- src/ZygoteRules.jl | 2 +- src/adjoint.jl | 8 -------- 2 files changed, 1 insertion(+), 9 deletions(-) diff --git a/src/ZygoteRules.jl b/src/ZygoteRules.jl index 0aaaccd..9d20f88 100644 --- a/src/ZygoteRules.jl +++ b/src/ZygoteRules.jl @@ -1,6 +1,6 @@ module ZygoteRules -export @adjoint, @adjoint!, @adjoint_keepthunks, @adjoint_keepthunks! +export @adjoint, @adjoint! 
""" ZygoteRules.literal_getproperty(x, ::Val{f}) diff --git a/src/adjoint.jl b/src/adjoint.jl index dc10961..47f628e 100644 --- a/src/adjoint.jl +++ b/src/adjoint.jl @@ -84,11 +84,3 @@ end macro adjoint!(ex) gradm(ex, true, false) end - -macro adjoint_keepthunks(ex) - gradm(ex, false, true) -end - -macro adjoint_keepthunks!(ex) - gradm(ex, true, true) -end