diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl index d5428e97e..ee2a69528 100644 --- a/src/compiler/interface.jl +++ b/src/compiler/interface.jl @@ -4,11 +4,14 @@ using Core: Typeof import Base: copy!, IdSet import Base.Broadcast: broadcasted, materialize! -mutable struct Context <: AContext +# Internal container used to track accumulated gradients of mutable types (including params). +# Type param I ∈ (true, false) indicates whether implicit params are in use. +# By default, this should be false unless pullback(f, ::Params) is called. +mutable struct Context{I} <: AContext cache::Union{IdDict{Any,Any},Nothing} end -Context() = Context(nothing) +Context() = Context{false}(nothing) cache(cx::Context) = cx.cache === nothing ? (cx.cache = IdDict()) : cx.cache @@ -36,10 +39,28 @@ _pullback(f, args...) = _pullback(Context(), f, args...) tailmemaybe(::Nothing) = nothing tailmemaybe(x::Tuple) = Base.tail(x) -function pullback(f, args...) - y, back = _pullback(f, args...) +@inline pullback(f, args...) = pullback(f, Context(), args...) +function pullback(f, cx::AContext, args...) + y, back = _pullback(cx, f, args...) y, Δ -> tailmemaybe(back(Δ)) end +function pullback(cx::Context, f, args...) + ChainRulesCore.ignore_derivatives() do + @warn """ + Incorrect argument order for pullback, please use: + + pullback(f, __context__::Context, args) + + instead of: + + pullback(__context__::Context, f, args) + + This is usually caused by a call to pullback in a higher-order @adjoint. + The above warning will become an error in Zygote 0.7. + """ + end + return pullback(f, cx, args...) +end sensitivity(y::Number) = one(y) sensitivity(y::Complex) = error("Output is complex, so the gradient is not defined.") @@ -334,21 +355,21 @@ function Base.map(f, gs1::Grads, gss::ADictOrGrads...) end function Base.map!(f, gsout::Grads, gss::ADictOrGrads...) - all(issetequal(gsout.params, keys(gs)) for gs in gss) || + all(issetequal(gsout.params, keys(gs)) for gs in gss) || throw(ArgumentError("map! expects Grads objects with the same Params.")) for p in gsout.params - gsout[p] = f((_getformap(gs, p) for gs in gss)...) + gsout[p] = f((_getformap(gs, p) for gs in gss)...) end return gsout end function _getformap(gs, p) g = gs[p] - isnothing(g) ? fill!(similar(p), 0) : g + isnothing(g) ? fill!(similar(p), 0) : g end function pullback(f, ps::Params) - cx = Context() + cx = Context{true}(nothing) y, back = _pullback(cx, f) y, function (Δ) for p in ps diff --git a/src/lib/array.jl b/src/lib/array.jl index 4e72713d9..293801b21 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -310,7 +310,7 @@ end @adjoint function sum(f, xs::AbstractArray{<:AbstractArray}; kws...) @assert !haskey(kws, :init) # TODO add init support (julia 1.6) - return pullback(__context__, (f, xs) -> sum(f.(xs); kws...), f, xs) + return pullback((f, xs) -> sum(f.(xs); kws...), __context__, f, xs) end @adjoint function sum(xs::AbstractArray{Bool}; dims = :) @@ -318,7 +318,7 @@ end end function _pullback(cx::AContext, ::typeof(prod), f, xs::AbstractArray) - y, back = pullback(cx, ((f, xs) -> prod(f.(xs))), f, xs) + y, back = pullback((f, xs) -> prod(f.(xs)), cx, f, xs) y, ȳ -> (nothing, back(ȳ)...) end diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl index b3c16e823..8c0d3c54c 100644 --- a/src/lib/broadcast.jl +++ b/src/lib/broadcast.jl @@ -30,6 +30,10 @@ using Base.Broadcast: Broadcasted, AbstractArrayStyle, broadcasted, materialize # Utilities # ========= +# ChainRules already marks this non-differentiable, +# But inference can still give up because of the Zygote -> CR wrapper layer +@nograd Broadcast.combine_styles + accum_sum(xs; dims = :) = reduce(accum, xs, dims = dims) # Work around reducedim_init issue @@ -82,16 +86,16 @@ _minus(::Nothing) = nothing @adjoint broadcasted(::typeof(*), x::Numeric, y::Numeric) = x.*y, Δ -> (nothing, unbroadcast(x, Δ .* conj.(y)), unbroadcast(y, Δ .* conj.(x))) @adjoint broadcasted(::typeof(*), x::Number, y::AbstractArray{<:Number}) = - _pullback(*, x, y) # this uses dot(y,Δ) instead of sum(Δ .* conj.(y)) + _pullback(__context__, *, x, y) # this uses dot(y,Δ) instead of sum(Δ .* conj.(y)) @adjoint broadcasted(::typeof(*), x::AbstractArray{<:Number}, y::Number) = - _pullback(*, x, y) + _pullback(__context__, *, x, y) @adjoint function broadcasted(::typeof(/), x::Numeric, y::Numeric) res = x ./ y res, Δ -> (nothing, unbroadcast(x, Δ ./ conj.(y)), unbroadcast(y, .-Δ .* conj.(res ./ y))) end @adjoint broadcasted(::typeof(/), x::AbstractArray{<:Number}, y::Number) = - _pullback(/, x, y) + _pullback(__context__, /, x, y) @adjoint function broadcasted(::typeof(Base.literal_pow), ::typeof(^), x::Numeric, exp::Val{p}) where p y = Base.literal_pow.(^, x, exp) @@ -273,7 +277,7 @@ using GPUArraysCore # replaces @require CUDA block, weird indenting to preserve # Not the ChainRules.rrule which will use the Zygote.Context and thus not be GPU compatible @adjoint function sum(f, xs::AbstractGPUArray; kws...) @assert !haskey(kws, :init) # TODO add init support (julia 1.6) - return pullback(__context__, (f, xs) -> sum(f.(xs); kws...), f, xs) + return pullback((f, xs) -> sum(f.(xs); kws...), __context__, f, xs) end @adjoint function Base.convert(::Type{T}, xs::Array) where {T<:AbstractGPUArray} diff --git a/src/lib/lib.jl b/src/lib/lib.jl index 22bda1e19..52a734809 100644 --- a/src/lib/lib.jl +++ b/src/lib/lib.jl @@ -21,7 +21,7 @@ accum(x, y) = accum(x, y, zs...) = accum(accum(x, y), zs...) -accum(x::Tuple, ys::Tuple...) = accum.(x, ys...) +accum(x::Tuple, ys::Tuple...) = map(accum, x, ys...) accum(x::AbstractArray, ys::AbstractArray...) = accum.(x, ys...) @generated function accum(x::NamedTuple, y::NamedTuple) @@ -48,6 +48,7 @@ end @adjoint Base.typeassert(x, T) = Base.typeassert(x, T), Δ -> (Δ, nothing) +accum_param(::Context{false}, _, Δ) = Δ @generated function accum_param(cx::Context, x, Δ) isbitstype(x) && return :(Δ) quote diff --git a/test/compiler.jl b/test/compiler.jl index bc37d271e..c5ddf1f38 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -1,5 +1,5 @@ using Zygote, Test -using Zygote: pullback, @adjoint +using Zygote: pullback, @adjoint, Context macro test_inferred(ex) :(let res = nothing @@ -160,13 +160,18 @@ end @testset "inference for `getproperty`" begin Gaussian = _Gaussian(:getproperty) g = Gaussian(randn(3), randn(3, 3)) - y, back = @inferred pullback(x -> x.m, g) - @test y == getfield(g, :m) - # This type instability is due to the handling of non-bitstypes in `accum_param` + y_explicit, back_explicit = @inferred pullback(x -> x.m, g) + y_implicit, back_implicit = @inferred pullback(x -> x.m, Context{true}(nothing), g) + @test y_explicit == y_implicit == getfield(g, :m) + + ∇args = ((m = [1.0, 0.0, 0.0], P = nothing),) if VERSION > v"1.7-" - @test Base.return_types(back, Tuple{Vector{Float64}}) == Any[Union{Tuple{Nothing}, typeof(((m = [1.0, 0.0, 0.0], P = nothing),))}] + # This type instability is due to the handling of non-bitstypes in `accum_param` + @test Base.return_types(back_implicit, Tuple{Vector{Float64}}) == Any[Union{Tuple{Nothing}, typeof(∇args)}] + # But the same should infer if implicit parameters are disabled + @test Base.return_types(back_explicit, Tuple{Vector{Float64}}) == Any[typeof(∇args)] end - @test back([1., 0, 0]) == ((m = [1.0, 0.0, 0.0], P = nothing),) + @test back_explicit([1., 0, 0]) == back_implicit([1., 0, 0]) == ∇args Base.getproperty(g::Gaussian, s::Symbol) = 2getfield(g, s) y, back = pullback(x -> x.m, g) diff --git a/test/features.jl b/test/features.jl index cdfe7329e..d4f68d36b 100644 --- a/test/features.jl +++ b/test/features.jl @@ -476,7 +476,7 @@ end @test_broken gradient(x -> abs2(x[1].x) + 7 * x[1].x.re, [Ref(1+im)]) == ([(x = 9.0 + 2.0im,)],) @test_broken gradient(x -> abs2(x[1].x) + 7 * real(x[1].x), [Ref(1+im)]) == ([(x = 9.0 + 2.0im,)],) # worked on 0.6.0, 0.6.20 - @test_broken gradient(x -> abs2(x[].x) + 7 * real(x[].x), Ref(Ref(1+im))) == ((x = 9.0 + 2.0im,),) # gives nothing, same in 0.6.0 + @test gradient(x -> abs2(x[].x) + 7 * real(x[].x), Ref(Ref(1+im))) == ((x = (x = 9.0 + 2.0im,),),) # gave `nothing` from 0.6.0 to 0.6.41 # Array of mutables: @test gradient(x -> sum(getindex.(x).^2), Ref.(1:3))[1] == [(;x=2i) for i in 1:3] @@ -490,6 +490,59 @@ end @test gradient(x -> sum(sum, Ref(x) .* [1,2,3]), [4,5]) == ([6.0, 6.0],) end +@testset "mutable accum_param bugs" begin + mutable struct Mut{T}; x::T; end + struct Imm{T}; x::T; end + + # Indexing a tuple containing a mutable struct gave `nothing` + x1 = (Mut(3.0),) + x2 = (Imm(3.0),) + x3 = (Ref(3.0),) + @test gradient(x -> x[1].x^2, x1)[1] == ((x = 6.0,),) # fails on v0.6.0 v0.6.41 + @test gradient(x -> x[1].x^2, x2)[1] == ((x = 6.0,),) + @test gradient(x -> x[1].x^2, x3)[1] == ((x = 6.0,),) # fails on v0.6.0 v0.6.41 + i1 = 1 + @test gradient(x -> x[i1].x^2, x1)[1] == ((x = 6.0,),) # fails on v0.6.0 v0.6.41 + @test gradient(x -> x[i1].x^2, x2)[1] == ((x = 6.0,),) + @test gradient(x -> x[i1].x^2, x3)[1] == ((x = 6.0,),) # fails on v0.6.0 v0.6.41 + + @test gradient(x -> x[1][1].x^2, [x1])[1] == [((x = 6.0,),)] # fails on v0.6.0 v0.6.41 + @test gradient(x -> x[1][1].x^2, [x2])[1] == [((x = 6.0,),)] + @test gradient(x -> x[1][1].x^2, [x3])[1] == [((x = 6.0,),)] # fails on v0.6.0 v0.6.41 + + # When `getfield` returns a mutable struct, it gave `nothing`: + x4 = Imm(Mut(4.0)) + x5 = Mut(Mut(4.0)) + x6 = Imm(Imm(4.0)) + @test gradient(x -> x.x.x^3, x4)[1] == (x = (x = 48.0,),) # fails on v0.6.0 v0.6.41 + @test gradient(x -> x.x.x^3, x5)[1] == (x = (x = 48.0,),) # fails on v0.6.0 + @test gradient(x -> x.x.x^3, x6)[1] == (x = (x = 48.0,),) # fails on v0.6.41 + + @test gradient(x -> x[2].x.x^3, [x4, x4])[1] == [nothing, (x = (x = 48.0,),)] # fails on v0.6.0 v0.6.41 + @test gradient(x -> x[2].x.x^3, [x4, x5])[1] == [nothing, (x = (x = 48.0,),)] # fails on v0.6.0 + @test gradient(x -> x[2].x.x^3, [x4, x6])[1] == [nothing, (x = (x = 48.0,),)] # fails on v0.6.41 + + # Check when using implicit parameters, Params cases used to pass: + y1 = [3.0] + y2 = (Mut(y1),) + y3 = (Imm(y1),) + @test gradient(x -> sum(x[1].x)^2, y2)[1] == ((x = [6.0],),) # fails on v0.6.0 v0.6.41 + @test gradient(() -> sum(y2[1].x)^2, Params([y1]))[y1] == [6.0] + @test gradient(x -> sum(x[1].x)^2, y3)[1] == ((x = [6.0],),) + @test gradient(() -> sum(y3[1].x)^2, Params([y1]))[y1] == [6.0] + + @test gradient(x -> sum(x[1].x .+ x[1].x)^3, y2)[1] == ((x = [216.0],),) # fails on v0.6.0 v0.6.41 + @test gradient(() -> sum(y2[1].x .+ y2[1].x)^3, Params([y1]))[y1] == [216.0] + @test gradient(x -> sum(x[1].x .+ x[1].x)^3, y3)[1] == ((x = [216.0],),) + @test gradient(() -> sum(y3[1].x .+ y3[1].x)^3, Params([y1]))[y1] == [216.0] + + i1 = 1 + @test gradient(x -> sum(x[i1].x .+ x[1].x)^3, y2)[1] == ((x = [216.0],),) # fails on v0.6.0 v0.6.41 + @test gradient(() -> sum(y2[i1].x .+ y2[1].x)^3, Params([y1]))[y1] == [216.0] + @test gradient(x -> sum(x[i1].x .+ x[1].x)^3, y3)[1] == ((x = [216.0],),) + @test gradient(() -> sum(y3[i1].x .+ y3[1].x)^3, Params([y1]))[y1] == [216.0] +end + @testset "NamedTuples" begin @test gradient(x -> x.a, (a=1, b=2)) == ((a = 1, b = nothing),) @test gradient(x -> x[1].a, [(a=1, b=2)]) == ([(a = 1, b = nothing)],) @@ -517,7 +570,7 @@ end @test (x->10*(x => 2)[2])'(100) === nothing @test gradient(x-> (:x => x)[2], 17) == (1,) - + d = Dict(:x=>1.0, :y=>3.0); @test gradient(d -> Dict(:x => d[:x])[:x], d) == (Dict(:x => 1),) end @@ -546,7 +599,7 @@ end # zip if VERSION >= v"1.5" # On Julia 1.4 and earlier, [x/y for (x,y) in zip(10:14, 1:10)] is a DimensionMismatch, - # while on 1.5 - 1.7 it stops early. + # while on 1.5 - 1.7 it stops early. @test gradient(10:14, 1:10) do xs, ys sum([x/y for (x,y) in zip(xs, ys)]) @@ -608,7 +661,7 @@ end # Iterators.Product with enumerate @test gradient([2 3; 4 5]) do xs - sum([x^i+y for (i,x) in enumerate(xs), y in xs]) + sum([x^i+y for (i,x) in enumerate(xs), y in xs]) end == ([8 112; 36 2004],) end