Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved type stability with explicit params #1248

Merged
merged 3 commits into from
Aug 1, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 29 additions & 8 deletions src/compiler/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@ using Core: Typeof
import Base: copy!, IdSet
import Base.Broadcast: broadcasted, materialize!

mutable struct Context <: AContext
# Internal container used to track accumulated gradients of mutable types (including params).
# Type param I ∈ (true, false) indicates whether implicit params are in use.
# By default, this should be false unless pullback(f, ::Params) is called.
mutable struct Context{I} <: AContext
ToucheSir marked this conversation as resolved.
Show resolved Hide resolved
cache::Union{IdDict{Any,Any},Nothing}
end

Context() = Context(nothing)
Context() = Context{false}(nothing)

# Return the context's gradient cache, creating an empty `IdDict` on first access.
function cache(cx::Context)
    existing = cx.cache
    existing === nothing || return existing
    # The assignment expression evaluates to the freshly created dictionary.
    return cx.cache = IdDict()
end

Expand Down Expand Up @@ -36,10 +39,28 @@ _pullback(f, args...) = _pullback(Context(), f, args...)
# Drop the first element of a tuple via `Base.tail`; `nothing` is passed through as-is.
function tailmemaybe(::Nothing)
    return nothing
end

function tailmemaybe(x::Tuple)
    return Base.tail(x)
end

function pullback(f, args...)
y, back = _pullback(f, args...)
@inline pullback(f, args...) = pullback(f, Context(), args...)
function pullback(f, cx::AContext, args...)
y, back = _pullback(cx, f, args...)
y, Δ -> tailmemaybe(back(Δ))
end
function pullback(cx::Context, f, args...)
mcabbott marked this conversation as resolved.
Show resolved Hide resolved
ChainRulesCore.ignore_derivatives() do
@warn """
Incorrect argument order for pullback, please use:

pullback(f, __context__::Context, args)

instead of:

pullback(__context__::Context, f, args)

This is usually caused by a call to pullback in a higher-order @adjoint.
The above warning will become an error in Zygote 0.7.
"""
end
return pullback(f, cx, args...)
end

# For a real-like scalar output, return the multiplicative identity of `y`.
function sensitivity(y::Number)
    return one(y)
end

# Complex outputs are rejected: this more specific method always throws.
function sensitivity(y::Complex)
    return error("Output is complex, so the gradient is not defined.")
end
Expand Down Expand Up @@ -334,21 +355,21 @@ function Base.map(f, gs1::Grads, gss::ADictOrGrads...)
end

function Base.map!(f, gsout::Grads, gss::ADictOrGrads...)
all(issetequal(gsout.params, keys(gs)) for gs in gss) ||
all(issetequal(gsout.params, keys(gs)) for gs in gss) ||
throw(ArgumentError("map! expects Grads objects with the same Params."))
for p in gsout.params
gsout[p] = f((_getformap(gs, p) for gs in gss)...)
gsout[p] = f((_getformap(gs, p) for gs in gss)...)
end
return gsout
end

function _getformap(gs, p)
g = gs[p]
isnothing(g) ? fill!(similar(p), 0) : g
isnothing(g) ? fill!(similar(p), 0) : g
end

function pullback(f, ps::Params)
cx = Context()
cx = Context{true}(nothing)
y, back = _pullback(cx, f)
y, function (Δ)
for p in ps
Expand Down
4 changes: 2 additions & 2 deletions src/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -310,15 +310,15 @@ end

@adjoint function sum(f, xs::AbstractArray{<:AbstractArray}; kws...)
@assert !haskey(kws, :init) # TODO add init support (julia 1.6)
return pullback(__context__, (f, xs) -> sum(f.(xs); kws...), f, xs)
return pullback((f, xs) -> sum(f.(xs); kws...), __context__, f, xs)
end

# Adjoint for summing a Bool array: the primal is the ordinary `sum`,
# and the pullback reports no gradient (`nothing`) for `xs`.
@adjoint function sum(xs::AbstractArray{Bool}; dims = :)
    total = sum(xs, dims = dims)
    return total, _ -> (nothing,)
end

function _pullback(cx::AContext, ::typeof(prod), f, xs::AbstractArray)
y, back = pullback(cx, ((f, xs) -> prod(f.(xs))), f, xs)
y, back = pullback((f, xs) -> prod(f.(xs)), cx, f, xs)
y, ȳ -> (nothing, back(ȳ)...)
end

Expand Down
12 changes: 8 additions & 4 deletions src/lib/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ using Base.Broadcast: Broadcasted, AbstractArrayStyle, broadcasted, materialize
# Utilities
# =========

# ChainRules already marks this as non-differentiable,
# but inference can still give up because of the Zygote -> ChainRules wrapper layer.
@nograd Broadcast.combine_styles

# Reduce `xs` with `accum` along `dims` (all dimensions by default).
function accum_sum(xs; dims = :)
    return reduce(accum, xs; dims = dims)
end

# Work around reducedim_init issue
Expand Down Expand Up @@ -82,16 +86,16 @@ _minus(::Nothing) = nothing
# Adjoint for elementwise multiplication `x .* y`.
# The pullback returns `nothing` for the function slot, then the cotangent of each
# argument: Δ times the conjugate of the other operand, passed through `unbroadcast`
# together with the corresponding original argument.
@adjoint function broadcasted(::typeof(*), x::Numeric, y::Numeric)
    product = x .* y
    times_back = function (Δ)
        x̄ = unbroadcast(x, Δ .* conj.(y))
        ȳ = unbroadcast(y, Δ .* conj.(x))
        return (nothing, x̄, ȳ)
    end
    return product, times_back
end
@adjoint broadcasted(::typeof(*), x::Number, y::AbstractArray{<:Number}) =
_pullback(*, x, y) # this uses dot(y,Δ) instead of sum(Δ .* conj.(y))
_pullback(__context__, *, x, y) # this uses dot(y,Δ) instead of sum(Δ .* conj.(y))
@adjoint broadcasted(::typeof(*), x::AbstractArray{<:Number}, y::Number) =
_pullback(*, x, y)
_pullback(__context__, *, x, y)

# Adjoint for elementwise division `x ./ y`.
# The primal quotient is kept so the pullback can reuse it for the `y` cotangent.
@adjoint function broadcasted(::typeof(/), x::Numeric, y::Numeric)
    quotient = x ./ y
    div_back = function (Δ)
        x̄ = unbroadcast(x, Δ ./ conj.(y))
        ȳ = unbroadcast(y, .-Δ .* conj.(quotient ./ y))
        # No gradient for the `/` function itself.
        return (nothing, x̄, ȳ)
    end
    return quotient, div_back
end
@adjoint broadcasted(::typeof(/), x::AbstractArray{<:Number}, y::Number) =
_pullback(/, x, y)
_pullback(__context__, /, x, y)

@adjoint function broadcasted(::typeof(Base.literal_pow), ::typeof(^), x::Numeric, exp::Val{p}) where p
y = Base.literal_pow.(^, x, exp)
Expand Down Expand Up @@ -273,7 +277,7 @@ using GPUArraysCore # replaces @require CUDA block, weird indenting to preserve
# Not the ChainRules.rrule which will use the Zygote.Context and thus not be GPU compatible
@adjoint function sum(f, xs::AbstractGPUArray; kws...)
@assert !haskey(kws, :init) # TODO add init support (julia 1.6)
return pullback(__context__, (f, xs) -> sum(f.(xs); kws...), f, xs)
return pullback((f, xs) -> sum(f.(xs); kws...), __context__, f, xs)
end

@adjoint function Base.convert(::Type{T}, xs::Array) where {T<:AbstractGPUArray}
Expand Down
3 changes: 2 additions & 1 deletion src/lib/lib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ accum(x, y) =

# Variadic accumulation: combine the first two arguments, then fold in the rest.
function accum(x, y, zs...)
    partial = accum(x, y)
    return accum(partial, zs...)
end

accum(x::Tuple, ys::Tuple...) = accum.(x, ys...)
accum(x::Tuple, ys::Tuple...) = map(accum, x, ys...)
# Arrays are accumulated elementwise by broadcasting `accum` over all arguments.
function accum(x::AbstractArray, ys::AbstractArray...)
    return accum.(x, ys...)
end

@generated function accum(x::NamedTuple, y::NamedTuple)
Expand All @@ -48,6 +48,7 @@ end

# Adjoint for `typeassert`: the primal performs the assertion, and the pullback
# passes Δ straight through to `x` with no gradient for the type argument `T`.
@adjoint function Base.typeassert(x, T)
    checked = Base.typeassert(x, T)
    return checked, Δ -> (Δ, nothing)
end

accum_param(::Context{false}, _, Δ) = Δ
@generated function accum_param(cx::Context, x, Δ)
isbitstype(x) && return :(Δ)
quote
Expand Down
17 changes: 11 additions & 6 deletions test/compiler.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using Zygote, Test
using Zygote: pullback, @adjoint
using Zygote: pullback, @adjoint, Context

macro test_inferred(ex)
:(let res = nothing
Expand Down Expand Up @@ -160,13 +160,18 @@ end
@testset "inference for `getproperty`" begin
Gaussian = _Gaussian(:getproperty)
g = Gaussian(randn(3), randn(3, 3))
y, back = @inferred pullback(x -> x.m, g)
@test y == getfield(g, :m)
# This type instability is due to the handling of non-bitstypes in `accum_param`
y_explicit, back_explicit = @inferred pullback(x -> x.m, g)
y_implicit, back_implicit = @inferred pullback(x -> x.m, Context{true}(nothing), g)
@test y_explicit == y_implicit == getfield(g, :m)

∇args = ((m = [1.0, 0.0, 0.0], P = nothing),)
if VERSION > v"1.7-"
@test Base.return_types(back, Tuple{Vector{Float64}}) == Any[Union{Tuple{Nothing}, typeof(((m = [1.0, 0.0, 0.0], P = nothing),))}]
# This type instability is due to the handling of non-bitstypes in `accum_param`
@test Base.return_types(back_implicit, Tuple{Vector{Float64}}) == Any[Union{Tuple{Nothing}, typeof(∇args)}]
# But the same should infer if implicit parameters are disabled
@test Base.return_types(back_explicit, Tuple{Vector{Float64}}) == Any[typeof(∇args)]
end
@test back([1., 0, 0]) == ((m = [1.0, 0.0, 0.0], P = nothing),)
@test back_explicit([1., 0, 0]) == back_implicit([1., 0, 0]) == ∇args

Base.getproperty(g::Gaussian, s::Symbol) = 2getfield(g, s)
y, back = pullback(x -> x.m, g)
Expand Down