Skip to content

Commit

Permalink
Improved type stability with explicit params
Browse files Browse the repository at this point in the history
We can disable accumulating (implicit) parameters to the gradient cache
in explicit mode. This can dramatically improve type stability because
`accum_param` will return a `Union{Nothing, [grad type]}` otherwise.
  • Loading branch information
ToucheSir committed Jun 24, 2022
1 parent b29d5b2 commit 7540bd6
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 21 deletions.
34 changes: 26 additions & 8 deletions src/compiler/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ using Core: Typeof
import Base: copy!, IdSet
import Base.Broadcast: broadcasted, materialize!

mutable struct Context <: AContext
mutable struct Context{I} <: AContext
cache::Union{IdDict{Any,Any},Nothing}
end

Context() = Context(nothing)
Context() = Context{false}(nothing)

# Return the cache dictionary stored in `cx`, creating an empty `IdDict` the first
# time it is requested (the `cache` field is `nothing` when built via `Context()`).
cache(cx::Context) = cx.cache === nothing ? (cx.cache = IdDict()) : cx.cache

Expand Down Expand Up @@ -36,10 +36,28 @@ _pullback(f, args...) = _pullback(Context(), f, args...)
# Drop the first entry of a tuple, letting `nothing` pass through untouched.
# `pullback` applies this to the value returned by `back(Δ)`.
function tailmemaybe(::Nothing)
    return nothing
end

function tailmemaybe(x::Tuple)
    return Base.tail(x)
end

function pullback(f, args...)
y, back = _pullback(f, args...)
@inline pullback(f, args...) = pullback(f, Context(), args...)
function pullback(f, cx::AContext, args...)
y, back = _pullback(cx, f, args...)
y, Δ -> tailmemaybe(back(Δ))
end
# Compatibility method for the old argument order, where the context comes
# before the function. It logs a warning describing the supported order and
# then forwards to `pullback(f, cx, args...)`. Per the warning text, this
# form is slated to become an error in Zygote 0.7.
function pullback(cx::Context, f, args...)
# The warning is emitted inside `ignore_derivatives` — presumably so the
# logging call itself is not differentiated; confirm against ChainRulesCore.
ChainRulesCore.ignore_derivatives() do
@warn """
Incorrect argument order for pullback, please use:
pullback(f, __context__::Context, args)
instead of:
pullback(__context__::Context, f, args)
This is usually caused by a call to pullback in a higher-order @adjoint.
The above warning will become an error in Zygote 0.7.
"""
end
# Delegate to the supported `(f, cx, args...)` method.
return pullback(f, cx, args...)
end

# Seed value for an output `y`: the multiplicative identity of `y` for
# ordinary numbers.
function sensitivity(y::Number)
    return one(y)
end

# Complex outputs are rejected with an error instead of producing a seed.
function sensitivity(y::Complex)
    return error("Output is complex, so the gradient is not defined.")
end
Expand Down Expand Up @@ -334,21 +352,21 @@ function Base.map(f, gs1::Grads, gss::ADictOrGrads...)
end

function Base.map!(f, gsout::Grads, gss::ADictOrGrads...)
all(issetequal(gsout.params, keys(gs)) for gs in gss) ||
all(issetequal(gsout.params, keys(gs)) for gs in gss) ||
throw(ArgumentError("map! expects Grads objects with the same Params."))
for p in gsout.params
gsout[p] = f((_getformap(gs, p) for gs in gss)...)
gsout[p] = f((_getformap(gs, p) for gs in gss)...)
end
return gsout
end

function _getformap(gs, p)
g = gs[p]
isnothing(g) ? fill!(similar(p), 0) : g
isnothing(g) ? fill!(similar(p), 0) : g
end

function pullback(f, ps::Params)
cx = Context()
cx = Context{true}(nothing)
y, back = _pullback(cx, f)
y, function (Δ)
for p in ps
Expand Down
4 changes: 2 additions & 2 deletions src/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ end

@adjoint function sum(f, xs::AbstractArray{<:AbstractArray}; kws...)
@assert !haskey(kws, :init) # TODO add init support (julia 1.6)
return pullback(__context__, (f, xs) -> sum(f.(xs); kws...), f, xs)
return pullback((f, xs) -> sum(f.(xs); kws...), __context__, f, xs)
end

@adjoint function sum(xs::AbstractArray{Bool}; dims = :)
Expand All @@ -315,7 +315,7 @@ end


function _pullback(cx::AContext, ::typeof(prod), f, xs::AbstractArray)
y, back = pullback(cx, ((f, xs) -> prod(f.(xs))), f, xs)
y, back = pullback((f, xs) -> prod(f.(xs)), cx, f, xs)
y, ȳ -> (nothing, back(ȳ)...)
end

Expand Down
12 changes: 8 additions & 4 deletions src/lib/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ using Base.Broadcast: Broadcasted, AbstractArrayStyle, broadcasted, materialize
# Utilities
# =========

# ChainRules already marks this as non-differentiable,
# but inference can still give up because of the Zygote -> ChainRules wrapper layer
@nograd Broadcast.combine_styles

# Reduce `xs` using `accum` as the combining operation. By default the whole
# collection is reduced (`dims = :`); pass `dims` to reduce along given dimensions.
function accum_sum(xs; dims = :)
    return reduce(accum, xs; dims = dims)
end

# Work around reducedim_init issue
Expand Down Expand Up @@ -82,16 +86,16 @@ _minus(::Nothing) = nothing
@adjoint broadcasted(::typeof(*), x::Numeric, y::Numeric) = x.*y,
Δ -> (nothing, unbroadcast(x, Δ .* conj.(y)), unbroadcast(y, Δ .* conj.(x)))
@adjoint broadcasted(::typeof(*), x::Number, y::AbstractArray{<:Number}) =
_pullback(*, x, y) # this uses dot(y,Δ) instead of sum(Δ .* conj.(y))
_pullback(__context__, *, x, y) # this uses dot(y,Δ) instead of sum(Δ .* conj.(y))
@adjoint broadcasted(::typeof(*), x::AbstractArray{<:Number}, y::Number) =
_pullback(*, x, y)
_pullback(__context__, *, x, y)

# Adjoint for broadcasted division of `x` by `y`.
# The forward value is `res = x ./ y`. The pullback returns a 3-tuple: a leading
# `nothing`, then `Δ ./ conj.(y)` for `x` and `-Δ .* conj.(res ./ y)` for `y`,
# each passed through `unbroadcast` together with the corresponding argument
# (presumably to bring the gradient back to that argument's shape — confirm
# against the definition of `unbroadcast`).
@adjoint function broadcasted(::typeof(/), x::Numeric, y::Numeric)
res = x ./ y
res, Δ -> (nothing, unbroadcast(x, Δ ./ conj.(y)), unbroadcast(y, .-Δ .* conj.(res ./ y)))
end
@adjoint broadcasted(::typeof(/), x::AbstractArray{<:Number}, y::Number) =
_pullback(/, x, y)
_pullback(__context__, /, x, y)

@adjoint function broadcasted(::typeof(Base.literal_pow), ::typeof(^), x::Numeric, exp::Val{p}) where p
y = Base.literal_pow.(^, x, exp)
Expand Down Expand Up @@ -284,7 +288,7 @@ end
# Not the ChainRules.rrule which will use the Zygote.Context and thus not be GPU compatible
@adjoint function sum(f, xs::CUDA.AbstractGPUArray; kws...)
@assert !haskey(kws, :init) # TODO add init support (julia 1.6)
return pullback(__context__, (f, xs) -> sum(f.(xs); kws...), f, xs)
return pullback((f, xs) -> sum(f.(xs); kws...), __context__, f, xs)
end

@adjoint function Base.convert(::Type{T}, xs::Array) where {T<:CUDA.AbstractGPUArray}
Expand Down
3 changes: 2 additions & 1 deletion src/lib/lib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ accum(x, y) =

# Variadic form: combine the first two values with `accum`, then keep folding
# the partial result together with the remaining arguments.
function accum(x, y, zs...)
    partial = accum(x, y)
    return accum(partial, zs...)
end

accum(x::Tuple, ys::Tuple...) = accum.(x, ys...)
accum(x::Tuple, ys::Tuple...) = map(accum, x, ys...)
accum(x::AbstractArray, ys::AbstractArray...) = accum.(x, ys...)

@generated function accum(x::NamedTuple, y::NamedTuple)
Expand Down Expand Up @@ -50,6 +50,7 @@ end

# Adjoint for `Base.typeassert`: the forward value is the result of the assertion
# itself, and the pullback hands Δ back for `x` and `nothing` for the type `T`.
@adjoint Base.typeassert(x, T) = Base.typeassert(x, T), Δ -> (Δ, nothing)

# For a `Context{false}` the incoming gradient Δ is returned as-is; this method
# never reads or writes the context, so nothing is accumulated for the parameter.
accum_param(::Context{false}, _, Δ) = Δ
@generated function accum_param(cx::Context, x, Δ)
isbitstype(x) && return :(Δ)
quote
Expand Down
17 changes: 11 additions & 6 deletions test/compiler.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using Zygote, Test
using Zygote: pullback, @adjoint
using Zygote: pullback, @adjoint, Context

macro test_inferred(ex)
:(let res = nothing
Expand Down Expand Up @@ -160,13 +160,18 @@ end
@testset "inference for `getproperty`" begin
Gaussian = _Gaussian(:getproperty)
g = Gaussian(randn(3), randn(3, 3))
y, back = @inferred pullback(x -> x.m, g)
@test y == getfield(g, :m)
# This type instability is due to the handling of non-bitstypes in `accum_param`
y_explicit, back_explicit = @inferred pullback(x -> x.m, g)
y_implicit, back_implicit = @inferred pullback(x -> x.m, Context{true}(nothing), g)
@test y_explicit == y_implicit == getfield(g, :m)

∇args = ((m = [1.0, 0.0, 0.0], P = nothing),)
if VERSION > v"1.7-"
@test Base.return_types(back, Tuple{Vector{Float64}}) == Any[Union{Tuple{Nothing}, typeof(((m = [1.0, 0.0, 0.0], P = nothing),))}]
# This type instability is due to the handling of non-bitstypes in `accum_param`
@test Base.return_types(back_implicit, Tuple{Vector{Float64}}) == Any[Union{Tuple{Nothing}, typeof(∇args)}]
# But the same should infer if implicit parameters are disabled
@test Base.return_types(back_explicit, Tuple{Vector{Float64}}) == Any[typeof(∇args)]
end
@test back([1., 0, 0]) == ((m = [1.0, 0.0, 0.0], P = nothing),)
@test back_explicit([1., 0, 0]) == back_implicit([1., 0, 0]) == ∇args

Base.getproperty(g::Gaussian, s::Symbol) = 2getfield(g, s)
y, back = pullback(x -> x.m, g)
Expand Down

0 comments on commit 7540bd6

Please sign in to comment.