diff --git a/.github/workflows/IntegrationTest.yml b/.github/workflows/IntegrationTest.yml index aabc25d3a..4e6b6461e 100644 --- a/.github/workflows/IntegrationTest.yml +++ b/.github/workflows/IntegrationTest.yml @@ -16,6 +16,7 @@ jobs: os: [ubuntu-latest] package: # - {user: dpsanders, repo: ReversePropagation.jl} + - {user: dfdx, repo: Yota.jl} - {user: FluxML, repo: Zygote.jl} # Diffractor needs to run on Julia nightly # include: diff --git a/Project.toml b/Project.toml index a04d6e52d..b33544714 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.42.0" +version = "1.43.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -13,6 +13,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" RealDot = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" [compat] ChainRulesCore = "1.15.3" @@ -25,6 +26,7 @@ JLArrays = "0.1" JuliaInterpreter = "0.8,0.9" RealDot = "0.1" StaticArrays = "1.2" +StructArrays = "0.6.11" julia = "1.6" [extras] diff --git a/src/ChainRules.jl b/src/ChainRules.jl index e323f7b6d..b314d7be7 100644 --- a/src/ChainRules.jl +++ b/src/ChainRules.jl @@ -4,6 +4,7 @@ using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broad using ChainRulesCore using Compat using Distributed +using GPUArraysCore: AbstractGPUArrayStyle using IrrationalConstants: logtwo, logten using LinearAlgebra using LinearAlgebra.BLAS @@ -11,6 +12,7 @@ using Random using RealDot: realdot using SparseArrays using Statistics +using StructArrays # Basically everything this package does is overloading these, so we make an exception # to the normal rule of only overload via `ChainRulesCore.rrule`. 
@@ -22,6 +24,9 @@ using ChainRulesCore: derivatives_given_output # numbers that we know commute under multiplication const CommutativeMulNumber = Union{Real,Complex} +# StructArrays +include("unzipped.jl") + include("rulesets/Core/core.jl") include("rulesets/Base/utils.jl") @@ -34,6 +39,7 @@ include("rulesets/Base/arraymath.jl") include("rulesets/Base/indexing.jl") include("rulesets/Base/sort.jl") include("rulesets/Base/mapreduce.jl") +include("rulesets/Base/broadcast.jl") include("rulesets/Distributed/nondiff.jl") diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index 6bea6e06c..c10ba6e71 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -72,6 +72,8 @@ function rrule(::Type{T}, x::Number, y::Number) where {T<:Complex} return (T(x, y), Complex_pullback) end +@scalar_rule complex(x) true + # `hypot` @scalar_rule hypot(x::Real) sign(x) diff --git a/src/rulesets/Base/broadcast.jl b/src/rulesets/Base/broadcast.jl new file mode 100644 index 000000000..be11eb76a --- /dev/null +++ b/src/rulesets/Base/broadcast.jl @@ -0,0 +1,393 @@ +using Base.Broadcast: Broadcast, broadcasted, Broadcasted, BroadcastStyle +const RCR = RuleConfig{>:HasReverseMode} +const TRI_NO = (NoTangent(), NoTangent(), NoTangent()) + +function rrule(::typeof(copy), bc::Broadcasted) + uncopy(Δ) = (NoTangent(), Δ) + return copy(bc), uncopy +end + +# Skip AD'ing through the axis computation +function rrule(::typeof(Broadcast.instantiate), bc::Broadcasted) + uninstantiate(Δ) = (NoTangent(), Δ) + return Broadcast.instantiate(bc), uninstantiate +end + + +##### +##### Split broadcasting +##### + +# For `z = g.(f.(xs))`, this finds `y = f.(x)` eagerly because the rules for either `f` or `g` may need it, +# and we don't know whether re-computing `y` is cheap. +# (We could check `f` first like `sum(f, x)` does, but checking whether `g` needs `y` is tricky.) + +# This rule has `::BroadcastStyle` in part because Zygote's generic rule does, to avoid ambiguities. 
+# It applies one step later in AD, and all args have `broadcastable(x)` thus many have `Ref(x)`, complicating some tests. +# But it also means that the lazy rules below do not need `::RuleConfig{>:HasReverseMode}` just for dispatch. + +function rrule(cfg::RCR, ::typeof(broadcasted), ::BroadcastStyle, f::F, args::Vararg{Any,N}) where {F,N} + T = Broadcast.combine_eltypes(f, args) + if T === Bool # TODO use nondifftype here + # 1: Trivial case: non-differentiable output, e.g. `x .> 0` + @debug("split broadcasting trivial", f, T) + bc_trivial_back(_) = (TRI_NO..., ntuple(Returns(ZeroTangent()), length(args))...) + return f.(args...), bc_trivial_back + elseif T <: Number && may_bc_derivatives(T, f, args...) + # 2: Fast path: use arguments & result to find derivatives. + return split_bc_derivatives(f, args...) + elseif T <: Number && may_bc_forwards(cfg, f, args...) + # 3: Future path: use `frule_via_ad`? + return split_bc_forwards(cfg, f, args...) + else + # 4: Slow path: collect all the pullbacks & apply them later. + return split_bc_pullbacks(cfg, f, args...) + end +end + +# Path 2: This is roughly what `derivatives_given_output` is designed for, should be fast. + +function may_bc_derivatives(::Type{T}, f::F, args::Vararg{Any,N}) where {T,F,N} + TΔ = Core.Compiler._return_type(derivatives_given_output, Tuple{T, F, map(_eltype, args)...}) + return isconcretetype(TΔ) +end + +_eltype(x) = eltype(x) # ... but try harder to avoid `eltype(Broadcast.broadcasted(+, [1,2,3], 4.5)) == Any`: +_eltype(bc::Broadcast.Broadcasted) = Broadcast.combine_eltypes(bc.f, bc.args) + +function split_bc_derivatives(f::F, arg) where {F} + @debug("split broadcasting derivative", f) + ys = f.(arg) + function bc_one_back(dys) # For f.(x) we do not need StructArrays / unzip at all + delta = broadcast(unthunk(dys), ys, arg) do dy, y, a + das = only(derivatives_given_output(y, f, a)) + dy * conj(only(das)) # possibly this * should be made nan-safe. 
+ end + return (TRI_NO..., ProjectTo(arg)(delta)) + end + bc_one_back(z::AbstractZero) = (TRI_NO..., z) + return ys, bc_one_back +end +function split_bc_derivatives(f::F, args::Vararg{Any,N}) where {F,N} + @debug("split broadcasting derivatives", f, N) + ys = f.(args...) + function bc_many_back(dys) + deltas = unzip_broadcast(unthunk(dys), ys, args...) do dy, y, as... + das = only(derivatives_given_output(y, f, as...)) + map(da -> dy * conj(da), das) # possibly this * should be made nan-safe. + end + dargs = map(unbroadcast, args, deltas) # ideally sum in unbroadcast could be part of unzip_broadcast? + return (TRI_NO..., dargs...) + end + bc_many_back(z::AbstractZero) = (TRI_NO..., map(Returns(z), args)...) + return ys, bc_many_back +end + +# Path 3: Use forward mode, or an `frule` if one exists. +# To allow `args...` we need either chunked forward mode, with `adot::Tuple` perhaps: +# https://github.com/JuliaDiff/ChainRulesCore.jl/issues/92 +# https://github.com/JuliaDiff/Diffractor.jl/pull/54 +# Or else we need to call the `f` multiple times, and maybe that's OK: +# We do know that `f` doesn't have parameters, so maybe it's pure enough, +# and split broadcasting may anyway change N^2 executions into N, e.g. `g.(v ./ f.(v'))`. +# We don't know `f` is cheap, but `split_bc_pullbacks` tends to be very slow. 
+ +function may_bc_forwards(cfg::C, f::F, args::Vararg{Any,N}) where {C,F,N} + Base.issingletontype(F) || return false + N==1 || return false # Could weaken this to 1 differentiable + cfg isa RuleConfig{>:HasForwardsMode} && return true # allows frule_via_ad + TA = map(_eltype, args) + TF = Core.Compiler._return_type(frule, Tuple{C, Tuple{NoTangent, TA...}, F, TA...}) + return isconcretetype(TF) && TF <: Tuple +end + +split_bc_forwards(cfg::RuleConfig{>:HasForwardsMode}, f::F, arg) where {F} = split_bc_inner(frule_via_ad, cfg, f, arg) +split_bc_forwards(cfg::RuleConfig, f::F, arg) where {F} = split_bc_inner(frule, cfg, f, arg) +function split_bc_inner(frule_fun::R, cfg::RuleConfig, f::F, arg) where {R,F} + @debug("split broadcasting forwards", frule_fun, f) + ys, ydots = unzip_broadcast(arg) do a + frule_fun(cfg, (NoTangent(), one(a)), f, a) + end + function back_forwards(dys) + delta = broadcast(ydots, unthunk(dys), arg) do ydot, dy, a + ProjectTo(a)(conj(ydot) * dy) # possibly this * should be made nan-safe. + end + return (TRI_NO..., ProjectTo(arg)(delta)) + end + back_forwards(z::AbstractZero) = (TRI_NO..., z) + return ys, back_forwards +end + +# Path 4: The most generic, save all the pullbacks. Can be 1000x slower. +# Since broadcast makes no guarantee about order of calls, and un-fusing +# can change the number of calls, don't bother to try to reverse the iteration. + +function split_bc_pullbacks(cfg::RCR, f::F, args::Vararg{Any,N}) where {F,N} + @debug("split broadcasting generic", f, N) + ys3, backs = unzip_broadcast(args...) do a... + rrule_via_ad(cfg, f, a...) + end + function back_generic(dys) + deltas = unzip_broadcast(backs, unthunk(dys)) do back, dy # (could be map, sizes match) + map(unthunk, back(dy)) + end + dargs = map(unbroadcast, args, Base.tail(deltas)) + df = ProjectTo(f)(sum(first(deltas))) + return (NoTangent(), NoTangent(), df, dargs...) + end + back_generic(z::AbstractZero) = (TRI_NO..., map(Returns(z), args)...) 
+ return ys3, back_generic +end + +# Don't run broadcasting on scalars +function rrule(cfg::RCR, ::typeof(broadcasted), ::BroadcastStyle, f::F, args::Number...) where {F} + @debug("split broadcasting scalar", f) + z, back = rrule_via_ad(cfg, f, args...) + return z, dz -> (NoTangent(), NoTangent(), back(dz)...) +end + +##### +##### Fused broadcasting +##### + +# For certain cheap operations we can easily allow fused broadcast; the forward pass may be run twice. +# Accept `x::Broadcasted` because they produce it; can't dispatch on eltype but `x` is assumed to contain `Number`s. + +const NumericOrBroadcast = Union{Number, AbstractArray{<:Number}, NTuple{<:Any,Number}, Broadcast.Broadcasted} + +##### Arithmetic: +, -, *, ^2, / + +function rrule(::typeof(broadcasted), ::typeof(+), xs::NumericOrBroadcast...) + @debug("broadcasting: plus", length(xs)) + function bc_plus_back(dy_raw) + dy = unthunk(dy_raw) + return (NoTangent(), NoTangent(), map(x -> unbroadcast(x, dy), xs)...) # no copies, this may return dx2 === dx3 + end + return broadcasted(+, xs...), bc_plus_back +end + +function rrule(::typeof(broadcasted), ::typeof(-), x::NumericOrBroadcast, y::NumericOrBroadcast) + @debug("broadcasting: minus 2") + function bc_minus_back(dz_raw) + dz = unthunk(dz_raw) + return (NoTangent(), NoTangent(), @thunk(unbroadcast(x, dz)), @thunk(-unbroadcast(y, dz))) + end + return broadcasted(-, x, y), bc_minus_back +end + +function rrule(::typeof(broadcasted), ::typeof(-), x::NumericOrBroadcast) + @debug("broadcasting: minus 1") + bc_minus_back(dy) = (NoTangent(), NoTangent(), @thunk -unthunk(dy)) + return broadcasted(-, x), bc_minus_back +end + +function rrule(::typeof(broadcasted), ::typeof(*), x::NumericOrBroadcast, y::NumericOrBroadcast) + @debug("broadcasting: times") + function bc_times_back(Δraw) + Δ = unthunk(Δraw) + return (NoTangent(), NoTangent(), _back_star(x, y, Δ), _back_star(y, x, Δ)) + end + return broadcasted(*, x, y), bc_times_back +end +_back_star(x, y, Δ) = @thunk 
unbroadcast(x, Δ .* conj.(y)) # this case probably isn't better than generic +_back_star(x::Number, y, Δ) = @thunk LinearAlgebra.dot(y, Δ) # ... but this is why the rule exists +_back_star(x::Bool, y, Δ) = NoTangent() +_back_star(x::Complex{Bool}, y, Δ) = NoTangent() # e.g. for fun.(im.*x) + +# This works, but not sure it improves any benchmarks. Needs corresponding scalar rule to avoid ambiguities. +function rrule(::typeof(broadcasted), ::typeof(*), x::NumericOrBroadcast, y::NumericOrBroadcast, zs::NumericOrBroadcast...) + @debug("broadcasting: times", 2 + length(zs)) + xy, back1 = rrule(broadcasted, *, x, y) + xyz, back2 = rrule(broadcasted, *, xy, zs...) + function bc_times3_back(dxyz) + _, _, dxy, dzs... = back2(dxyz) + _, _, dx, dy = back1(dxy) + return (NoTangent(), NoTangent(), dx, dy, dzs...) + end + xyz, bc_times3_back +end + +function rrule(::typeof(broadcasted), ::typeof(Base.literal_pow), ::typeof(^), x::NumericOrBroadcast, ::Val{2}) + @debug("broadcasting: square") + function bc_square_back(dy_raw) + dx = @thunk ProjectTo(x)(2 .* unthunk(dy_raw) .* conj.(x)) + return (NoTangent(), NoTangent(), NoTangent(), dx, NoTangent()) + end + return broadcasted(Base.literal_pow, ^, x, Val(2)), bc_square_back +end + +function rrule(::typeof(broadcasted), ::typeof(/), x::NumericOrBroadcast, y::Number) + @debug("broadcasting: divide") + # z = broadcast(/, x, y) + z = broadcasted(/, x, y) + function bc_divide_back(dz_raw) + dz = unthunk(dz_raw) + dx = @thunk unbroadcast(x, dz ./ conj.(y)) + # dy = @thunk -LinearAlgebra.dot(z, dz) / conj(y) # the reason to be eager is to allow dot here + dy = @thunk -sum(Broadcast.instantiate(broadcasted(*, broadcasted(conj, z), dz))) / conj(y) # complete sum is fast + return (NoTangent(), NoTangent(), dx, dy) + end + return z, bc_divide_back +end + +# For the same functions, send accidental broadcasting over numbers directly to `rrule`. +# (Could perhaps move all to @scalar_rule?) 
+ +function _prepend_zero((y, back)) + extra_back(dy) = (NoTangent(), back(dy)...) + return y, extra_back +end + +rrule(::typeof(broadcasted), ::typeof(+), args::Number...) = rrule(+, args...) |> _prepend_zero +rrule(::typeof(broadcasted), ::typeof(-), x::Number, y::Number) = rrule(-, x, y) |> _prepend_zero +rrule(::typeof(broadcasted), ::typeof(-), x::Number) = rrule(-, x) |> _prepend_zero +rrule(::typeof(broadcasted), ::typeof(*), args::Number...) = rrule(*, args...) |> _prepend_zero +rrule(::typeof(broadcasted), ::typeof(*), x::Number, y::Number) = rrule(*, x, y) |> _prepend_zero # ambiguity +rrule(::typeof(broadcasted), ::typeof(*), x::Number, y::Number, zs::Number...) = rrule(*, x, y, zs...) |> _prepend_zero +rrule(::typeof(broadcasted), ::typeof(Base.literal_pow), ::typeof(^), x::Number, ::Val{2}) = + rrule(Base.literal_pow, ^, x, Val(2)) |> _prepend_zero +rrule(::typeof(broadcasted), ::typeof(/), x::Number, y::Number) = rrule(/, x, y) |> _prepend_zero + +##### Identity, number types + +rrule(::typeof(broadcasted), ::typeof(identity), x::NumericOrBroadcast) = rrule(identity, x) |> _prepend_zero +rrule(::typeof(broadcasted), ::typeof(identity), x::Number) = rrule(identity, x) |> _prepend_zero # ambiguity + +function rrule(::typeof(broadcasted), ::Type{T}, x::NumericOrBroadcast) where {T<:Number} + @debug("broadcasting: type", T) + bc_type_back(dz) = (NoTangent(), NoTangent(), @thunk(unbroadcast(x, unthunk(dz)))) + return broadcasted(T, x), bc_type_back +end +rrule(::typeof(broadcasted), ::Type{T}, x::Number) where {T<:Number} = rrule(T, x) |> _prepend_zero + +function rrule(::typeof(broadcasted), ::typeof(float), x::NumericOrBroadcast) + @debug("broadcasting: float") + bc_float_back(dz) = (NoTangent(), NoTangent(), @thunk(unbroadcast(x, unthunk(dz)))) + return broadcasted(float, x), bc_float_back +end +rrule(::typeof(broadcasted), ::typeof(float), x::Number) = rrule(float, x) |> _prepend_zero + +##### Complex: conj, real, imag + +for conj in [:conj, :adjoint] 
# identical as we know eltype <: Number + @eval begin + function rrule(::typeof(broadcasted), ::typeof($conj), x::NumericOrBroadcast) + bc_conj_back(dx) = (NoTangent(), NoTangent(), conj(unthunk(dx))) + return broadcasted($conj, x), bc_conj_back + end + rrule(::typeof(broadcasted), ::typeof($conj), x::Number) = rrule($conj, x) |> _prepend_zero + rrule(::typeof(broadcasted), ::typeof($conj), x::AbstractArray{<:Real}) = rrule(identity, x) |> _prepend_zero + # This `AbstractArray{<:Real}` rule won't catch `conj.(x.+1)` with lazy `.+` rule. + # Could upgrade to infer eltype of the `Broadcasted`? + end +end + +function rrule(::typeof(broadcasted), ::typeof(real), x::NumericOrBroadcast) + @debug("broadcasting: real") + bc_real_back(dz) = (NoTangent(), NoTangent(), @thunk(real(unthunk(dz)))) + return broadcasted(real, x), bc_real_back +end +rrule(::typeof(broadcasted), ::typeof(real), x::Number) = rrule(real, x) |> _prepend_zero +rrule(::typeof(broadcasted), ::typeof(real), x::AbstractArray{<:Real}) = rrule(identity, x) |> _prepend_zero + +function rrule(::typeof(broadcasted), ::typeof(imag), x::NumericOrBroadcast) + @debug("broadcasting: imag") + bc_imag_back(dz) = (NoTangent(), NoTangent(), @thunk(im .* real.(unthunk(dz)))) + return broadcasted(imag, x), bc_imag_back +end +rrule(::typeof(broadcasted), ::typeof(imag), x::Number) = rrule(imag, x) |> _prepend_zero +function rrule(::typeof(broadcasted), ::typeof(imag), x::AbstractArray{<:Real}) + @debug("broadcasting: imag(real)") + bc_imag_back_2(dz) = (NoTangent(), NoTangent(), ZeroTangent()) + return broadcasted(imag, x), bc_imag_back_2 +end + +function rrule(::typeof(broadcasted), ::typeof(complex), x::NumericOrBroadcast) + @debug("broadcasting: complex") + bc_complex_back(dz) = (NoTangent(), NoTangent(), @thunk(unbroadcast(x, unthunk(dz)))) + return broadcasted(complex, x), bc_complex_back +end +rrule(::typeof(broadcasted), ::typeof(complex), x::Number) = rrule(complex, x) |> _prepend_zero + +##### +##### Shape fixing 
+##### + +# When sizes disagree, broadcasting gradient uses `unbroadcast` to reduce to correct shape. +# It's sometimes a little wasteful to allocate a too-large `dx`, but difficult to make more efficient. + +function unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx) + N = ndims(dx) + if length(x) == length(dx) + ProjectTo(x)(dx) # handles trivial reshapes, offsets, structured matrices, row vectors + else + dims = ntuple(d -> get(size(x), d, 1) == 1 ? d : N+1, N) # hack to get type-stable `dims` + ProjectTo(x)(sum(dx; dims)) + end +end +unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx::AbstractZero) = dx + +function unbroadcast(x::T, dx) where {T<:Tuple{Vararg{Any,N}}} where {N} + val = if length(x) == length(dx) + dx + else + sum(dx; dims=2:ndims(dx)) + end + eltype(val) <: AbstractZero && return NoTangent() + return ProjectTo(x)(NTuple{length(x)}(val)) # Tangent +end +unbroadcast(x::Tuple, dx::AbstractZero) = dx + +# Scalar types + +unbroadcast(x::Number, dx) = ProjectTo(x)(sum(dx)) + +function unbroadcast(x::T, dx) where {T<:Tuple{Any}} + p1 = ProjectTo(only(x)) + p1 isa ProjectTo{<:AbstractZero} && return NoTangent() + dx1 = p1(sum(dx)) + dx1 isa AbstractZero && return dx1 + return Tangent{T}(dx1) +end +unbroadcast(x::Tuple{Any}, dx::AbstractZero) = dx + +function unbroadcast(x::Base.RefValue, dx) + p1 = ProjectTo(x.x) + p1 isa ProjectTo{<:AbstractZero} && return NoTangent() + dx1 = p1(sum(dx)) + dx1 isa AbstractZero && return dx1 + return Tangent{typeof(x)}(; x = dx1) +end +unbroadcast(x::Base.RefValue, dx::AbstractZero) = dx + +# Zero types + +unbroadcast(::Bool, dx) = NoTangent() +unbroadcast(::AbstractArray{Bool}, dx) = NoTangent() +unbroadcast(::AbstractArray{Bool}, dx::AbstractZero) = dx # ambiguity +unbroadcast(::Val, dx) = NoTangent() + +function unbroadcast(f::Function, df) + Base.issingletontype(typeof(f)) && return NoTangent() + return sum(df) +end + +##### +##### For testing +##### + +function rrule(cfg::RCR, ::typeof(copy∘broadcasted), 
f_args...) + tmp = rrule(cfg, broadcasted, f_args...) + isnothing(tmp) && return nothing + y, back = tmp + return _maybe_copy(y), back +end +function rrule(::typeof(copy∘broadcasted), f_args...) + tmp = rrule(broadcasted, f_args...) + isnothing(tmp) && return nothing + y, back = tmp + return _maybe_copy(y), back +end + +_maybe_copy(y) = copy(y) +_maybe_copy(y::Tuple) = y diff --git a/src/rulesets/Base/fastmath_able.jl b/src/rulesets/Base/fastmath_able.jl index 5a9a9b08c..d8afb630a 100644 --- a/src/rulesets/Base/fastmath_able.jl +++ b/src/rulesets/Base/fastmath_able.jl @@ -167,6 +167,16 @@ let @scalar_rule x + y (true, true) @scalar_rule x - y (true, -1) @scalar_rule x / y (one(x) / y, -(Ω / y)) + + ## many-arg + + function frule((_, Δx, Δy...), ::typeof(+), x::Number, ys::Number...) + +(x, ys...), +(Δx, Δy...) + end + + function rrule(::typeof(+), x::Number, ys::Number...) + plus_back(dz) = (NoTangent(), dz, map(Returns(dz), ys)...) + +(x, ys...), plus_back + end ## power # literal_pow is in base.jl @@ -276,6 +286,10 @@ let return Ω4, times_pullback4 end rrule(::typeof(*), x::Number) = rrule(identity, x) + + # This is used to choose a faster path in some broadcasting operations: + ChainRulesCore.derivatives_given_output(Ω, ::typeof(*), x::Number, y::Number) = tuple((y', x')) + ChainRulesCore.derivatives_given_output(Ω, ::typeof(*), x::Number, y::Number, z::Number) = tuple((y'z', x'z', x'y')) end # fastable_ast # Rewrite everything to use fast_math functions, including the type-constraints @@ -292,8 +306,8 @@ let "Non-FastMath compatible rules defined in fastmath_able.jl. \n Definitions:\n" * join(non_transformed_definitions, "\n") ) - # This error() may not play well with Revise. But a wanring @error does: - # @error "Non-FastMath compatible rules defined in fastmath_able.jl." non_transformed_definitions + # This error() may not play well with Revise. 
But a warning @error does, we should change it: + @error "Non-FastMath compatible rules defined in fastmath_able.jl." non_transformed_definitions end eval(fast_ast) diff --git a/src/unzipped.jl b/src/unzipped.jl new file mode 100644 index 000000000..fe5875e6f --- /dev/null +++ b/src/unzipped.jl @@ -0,0 +1,162 @@ +##### +##### broadcast +##### + +""" + unzip_broadcast(f, args...) + +For a function `f` which returns a tuple, this is `== unzip(broadcast(f, args...))`, +but performed using `StructArrays` for efficiency. Used in the gradient of broadcasting. + +# Examples +``` +julia> using ChainRules: unzip_broadcast, unzip + +julia> unzip_broadcast(x -> (x,2x), 1:3) +([1, 2, 3], [2, 4, 6]) + +julia> mats = @btime unzip_broadcast((x,y) -> (x+y, x-y), 1:1000, transpose(1:1000)); # 2 arrays, each 7.63 MiB + min 1.776 ms, mean 20.421 ms (4 allocations, 15.26 MiB) + +julia> mats == @btime unzip(broadcast((x,y) -> (x+y, x-y), 1:1000, transpose(1:1000))) # intermediate matrix of tuples + min 2.660 ms, mean 40.007 ms (6 allocations, 30.52 MiB) +true +``` +""" +function unzip_broadcast(f::F, args...) where {F} + T = Broadcast.combine_eltypes(f, args) + if isconcretetype(T) + T <: Tuple || throw(ArgumentError("""unzip_broadcast(f, args) only works on functions returning a tuple, + but f = $(sprint(show, f)) returns type T = $T""")) + end + bc = Broadcast.instantiate(Broadcast.broadcasted(f, args...)) + bcs = Broadcast.BroadcastStyle(typeof(bc)) + if bcs isa AbstractGPUArrayStyle + # This is a crude way to allow GPU arrays, not currently tested, TODO. + # See also https://github.com/JuliaArrays/StructArrays.jl/issues/150 + return unzip(broadcast(f, args...)) + elseif bcs isa Broadcast.AbstractArrayStyle + return StructArrays.components(StructArray(bc)) + else + return unzip(broadcast(f, args...)) # e.g. tuples + end + # TODO maybe this if-else can be replaced by methods of `unzip(::Broadcast.Broadcasted)`? 
+end + +function ChainRulesCore.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(unzip_broadcast), f::F, args...) where {F} + y, back = rrule_via_ad(cfg, broadcast, f, args...) + z = unzip(y) + function untuplecast(dz) + # dy = StructArray(map(unthunk, dz)) # fails for e.g. StructArray(([1,2,3], ZeroTangent())) + dy = broadcast(tuple, map(unthunk, dz)...) + db, df, dargs... = back(dy) + return (db, sum(df), map(unbroadcast, args, dargs)...) + end + untuplecast(dz::AbstractZero) = (NoTangent(), NoTangent(), map(Returns(dz), args)) + return z, untuplecast +end + +# This is for testing, but the tests using it don't work. +function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(collect∘unzip_broadcast), f, args...) + y, back = rrule(cfg, unzip_broadcast, f, args...) + return collect(y), back +end + +##### +##### map +##### + +# `unzip_map` can use `StructArrays.components(StructArray(Iterators.map(f, args...)))`, +# will be useful for the gradient of `map` etc. + + +##### +##### unzip +##### + +""" + unzip(A) + +Converts an array of tuples into a tuple of arrays. +Eager. Will work by `reinterpret` when possible. + +```jldoctest +julia> ChainRules.unzip([(1,2), (30,40), (500,600)]) # makes two new Arrays: +([1, 30, 500], [2, 40, 600]) + +julia> typeof(ans) +Tuple{Vector{Int64}, Vector{Int64}} + +julia> ChainRules.unzip([(1,nothing) (3,nothing) (5,nothing)]) # this can reinterpret: +([1 3 5], [nothing nothing nothing]) + +julia> ans[1] +1×3 reinterpret(Int64, ::Matrix{Tuple{Int64, Nothing}}): + 1 3 5 +``` +""" +function unzip(xs::AbstractArray) + x1 = first(xs) + x1 isa Tuple || throw(ArgumentError("unzip only accepts arrays of tuples")) + N = length(x1) + return unzip(xs, Val(N)) # like Zygote's unzip. Here this is the fallback case. +end + +@generated function unzip(xs, ::Val{N}) where {N} + each = [:(map($(Get(i)), xs)) for i in 1:N] + Expr(:tuple, each...) 
+end + +unzip(xs::AbstractArray{Tuple{T}}) where {T} = (reinterpret(T, xs),) # best case, no copy + +@generated function unzip(xs::AbstractArray{Ts}) where {Ts<:Tuple} + each = if count(!Base.issingletontype, Ts.parameters) < 2 + # good case, no copy of data, some trivial arrays + [Base.issingletontype(T) ? :(similar(xs, $T)) : :(reinterpret($T, xs)) for T in Ts.parameters] + else + [:(map($(Get(i)), xs)) for i in 1:length(fieldnames(Ts))] + end + Expr(:tuple, each...) +end + +""" + unzip(t) + +Also works on a tuple of tuples: + +```jldoctest +julia> unzip(((1,2), (30,40), (500,600))) +((1, 30, 500), (2, 40, 600)) +``` +""" +function unzip(xs::Tuple) + x1 = first(xs) + x1 isa Tuple || throw(ArgumentError("unzip only accepts arrays or tuples of tuples")) + return ntuple(i -> map(Get(i), xs),length(x1)) +end + +struct Get{i} end +Get(i) = Get{Int(i)}() +(::Get{i})(x) where {i} = x[i] + +function ChainRulesCore.rrule(::typeof(unzip), xs::AbstractArray{T}) where {T <: Tuple} + function rezip(dy) + dxs = broadcast(xs, unthunk.(dy)...) do x, ys... + ProjectTo(x)(Tangent{T}(ys...)) + end + return (NoTangent(), dxs) + end + rezip(dz::AbstractZero) = (NoTangent(), dz) + return unzip(xs), rezip +end + +function ChainRulesCore.rrule(::typeof(unzip), xs::Tuple) + function rezip_2(dy) + dxs = broadcast(xs, unthunk.(dy)...) do x, ys... + Tangent{typeof(x)}(ys...) 
+ end + return (NoTangent(), ProjectTo(xs)(dxs)) + end + rezip_2(dz::AbstractZero) = (NoTangent(), dz) + return unzip(xs), rezip_2 +end diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index a65881747..36452da1e 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -77,6 +77,7 @@ for x in (-4.1, 6.4, 0.0, 0.0 + 0.0im, 0.5 + 0.25im) test_scalar(real, x) test_scalar(imag, x) + test_scalar(complex, x) test_scalar(hypot, x) test_scalar(adjoint, x) end diff --git a/test/rulesets/Base/broadcast.jl b/test/rulesets/Base/broadcast.jl new file mode 100644 index 000000000..68d47a7d4 --- /dev/null +++ b/test/rulesets/Base/broadcast.jl @@ -0,0 +1,176 @@ +using Base.Broadcast: broadcasted + +if VERSION < v"1.7" + Base.ndims(::Type{<:AbstractArray{<:Any,N}}) where {N} = N +end +BS0 = Broadcast.BroadcastStyle(Float64) +BS1 = Broadcast.BroadcastStyle(Vector) # without ndims method, error on 1.6 +BS2 = Broadcast.BroadcastStyle(Matrix) + +BT1 = Broadcast.BroadcastStyle(Tuple) + +@testset "Broadcasting" begin + @testset "split 1: trivial path" begin + # test_rrule(copy∘broadcasted, >, rand(3), rand(3)) # MethodError: no method matching eps(::UInt64) inside FiniteDifferences + y1, bk1 = rrule(CFG, copy∘broadcasted, BS1, >, rand(3), rand(3)) + @test y1 isa AbstractArray{Bool} + @test all(d -> d isa AbstractZero, bk1(99)) + + y2, bk2 = rrule(CFG, copy∘broadcasted, BT1, isinteger, Tuple(rand(3))) + @test y2 isa Tuple{Bool,Bool,Bool} + @test all(d -> d isa AbstractZero, bk2(99)) + end + + @testset "split 2: derivatives" begin + test_rrule(copy∘broadcasted, BS1, log, rand(3) .+ 1) + test_rrule(copy∘broadcasted, BT1, log, Tuple(rand(3) .+ 1)) + + # Two args uses StructArrays + test_rrule(copy∘broadcasted, BS1, atan, rand(3), rand(3)) + test_rrule(copy∘broadcasted, BS2, atan, rand(3), rand(4)') + test_rrule(copy∘broadcasted, BS1, atan, rand(3), rand()) + test_rrule(copy∘broadcasted, BT1, atan, rand(3), Tuple(rand(1))) + test_rrule(copy∘broadcasted, 
BT1, atan, Tuple(rand(3)), Tuple(rand(3)), check_inferred = VERSION > v"1.7") + + # test_rrule(copy∘broadcasted, *, BS1, rand(3), Ref(rand())) # don't know what I was testing + end + + @testset "split 3: forwards" begin + # In test_helpers.jl, `flog` and `fstar` have only `frule`s defined, nothing else. + test_rrule(copy∘broadcasted, BS1, flog, rand(3)) + test_rrule(copy∘broadcasted, BS1, flog, rand(3) .+ im) + # Also, `sin∘cos` may use this path as CFG uses frule_via_ad + # TODO use different CFGs, https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/255 + end + + @testset "split 4: generic" begin + test_rrule(copy∘broadcasted, BS1, sin∘cos, rand(3), check_inferred=false) + test_rrule(copy∘broadcasted, BS2, sin∘atan, rand(3), rand(3)', check_inferred=false) + test_rrule(copy∘broadcasted, BS1, sin∘atan, rand(), rand(3), check_inferred=false) + test_rrule(copy∘broadcasted, BS1, ^, rand(3), 3.0, check_inferred=false) # NoTangent vs. Union{NoTangent, ZeroTangent} + # Many have quite small inference failures, like: + # return type Tuple{NoTangent, NoTangent, Vector{Float64}, Float64} does not match inferred + # return type Tuple{NoTangent, Union{NoTangent, ZeroTangent}, Vector{Float64}, Float64} + + # From test_helpers.jl + test_rrule(copy∘broadcasted, BS1, Multiplier(rand()), rand(3), check_inferred=false) + test_rrule(copy∘broadcasted, BS2, Multiplier(rand()), rand(3), rand(4)', check_inferred=false) # Union{ZeroTangent, Tangent{Multiplier{... + @test_skip test_rrule(copy∘broadcasted, BS1, Multiplier(rand()), rand(3), 5.0im, check_inferred=false) # ProjectTo(f) fails to remove the imaginary part of Multiplier's gradient + test_rrule(copy∘broadcasted, BS1, make_two_vec, rand(3), check_inferred=false) + + # Non-diff components -- note that with BroadcastStyle, Ref is from e.g. 
Broadcast.broadcastable(nothing) + test_rrule(copy∘broadcasted, BS2, first∘tuple, rand(3), Ref(:sym), rand(4)', check_inferred=false) + test_rrule(copy∘broadcasted, BS2, last∘tuple, rand(3), Ref(nothing), rand(4)', check_inferred=false) + test_rrule(copy∘broadcasted, BS1, |>, rand(3), Ref(sin), check_inferred=false) + _call(f, x...) = f(x...) + test_rrule(copy∘broadcasted, BS2, _call, Ref(atan), rand(3), rand(4)', check_inferred=false) + + test_rrule(copy∘broadcasted, BS1, getindex, [rand(3) for _ in 1:2], [3,1], check_inferred=false) + test_rrule(copy∘broadcasted, BS1, getindex, [rand(3) for _ in 1:2], (3,1), check_inferred=false) + test_rrule(copy∘broadcasted, BS1, getindex, [rand(3) for _ in 1:2], Ref(CartesianIndex(2)), check_inferred=false) + test_rrule(copy∘broadcasted, BT1, getindex, Tuple([rand(3) for _ in 1:2]), (3,1), check_inferred=false) + test_rrule(copy∘broadcasted, BT1, getindex, Tuple([Tuple(rand(3)) for _ in 1:2]), (3,1), check_inferred=false) + + # Protected by Ref/Tuple: + test_rrule(copy∘broadcasted, BS1, *, rand(3), Ref(rand(2)), check_inferred=false) + test_rrule(copy∘broadcasted, BS1, conj∘*, rand(3), Ref(rand() + im), check_inferred=false) + test_rrule(copy∘broadcasted, BS1, conj∘*, rand(3), Ref(rand(2) .+ im), check_inferred=false) + test_rrule(copy∘broadcasted, BS1, /, (rand(2),), rand(3), check_inferred=false) + end + + @testset "fused rules" begin + @testset "arithmetic" begin + @gpu test_rrule(copy∘broadcasted, +, rand(3), rand(3)) + @gpu test_rrule(copy∘broadcasted, +, rand(3), rand(4)') + @gpu test_rrule(copy∘broadcasted, +, rand(3), rand(1), rand()) + @gpu test_rrule(copy∘broadcasted, +, rand(3), 1.0*im) + @gpu test_rrule(copy∘broadcasted, +, rand(3), true) + @gpu_broken test_rrule(copy∘broadcasted, +, rand(3), Tuple(rand(3))) + + @gpu test_rrule(copy∘broadcasted, -, rand(3), rand(3)) + @gpu test_rrule(copy∘broadcasted, -, rand(3), rand(4)') + @gpu test_rrule(copy∘broadcasted, -, rand(3)) + test_rrule(copy∘broadcasted, -, 
Tuple(rand(3))) + + @gpu test_rrule(copy∘broadcasted, *, rand(3), rand(3)) + @gpu test_rrule(copy∘broadcasted, *, rand(3), rand()) + @gpu test_rrule(copy∘broadcasted, *, rand(), rand(3)) + + test_rrule(copy∘broadcasted, *, rand(3) .+ im, rand(3) .+ 2im) + test_rrule(copy∘broadcasted, *, rand(3) .+ im, rand() + 3im) + test_rrule(copy∘broadcasted, *, rand() + im, rand(3) .+ 4im) + + @test_skip test_rrule(copy∘broadcasted, *, im, rand(3)) # MethodError: no method matching randn(::Random._GLOBAL_RNG, ::Type{Complex{Bool}}) + @test_skip test_rrule(copy∘broadcasted, *, rand(3), im) # MethodError: no method matching randn(::Random._GLOBAL_RNG, ::Type{Complex{Bool}}) + y4, bk4 = rrule(CFG, copy∘broadcasted, *, im, [1,2,3.0]) + @test y4 == [im, 2im, 3im] + @test unthunk(bk4([4, 5im, 6+7im])[4]) == [0,5,7] + + # These two test vararg rrule * rule: + @gpu test_rrule(copy∘broadcasted, *, rand(3), rand(3), rand(3), rand(3), rand(3)) + @gpu_broken test_rrule(copy∘broadcasted, *, rand(), rand(), rand(3), rand(3) .+ im, rand(4)') + # GPU error from dot(x::JLArray{Float32, 1}, y::JLArray{ComplexF32, 2}) + + @gpu test_rrule(copy∘broadcasted, Base.literal_pow, ^, rand(3), Val(2)) + @gpu test_rrule(copy∘broadcasted, Base.literal_pow, ^, rand(3) .+ im, Val(2)) + + @gpu test_rrule(copy∘broadcasted, /, rand(3), rand()) + @gpu test_rrule(copy∘broadcasted, /, rand(3) .+ im, rand() + 3im) + end + @testset "identity etc" begin + test_rrule(copy∘broadcasted, identity, rand(3)) + + test_rrule(copy∘broadcasted, Float32, rand(3), rtol=1e-4) + test_rrule(copy∘broadcasted, ComplexF32, rand(3), rtol=1e-4) + + test_rrule(copy∘broadcasted, float, rand(3)) + end + @testset "complex" begin + test_rrule(copy∘broadcasted, conj, rand(3)) + test_rrule(copy∘broadcasted, conj, rand(3) .+ im) + test_rrule(copy∘broadcasted, adjoint, rand(3)) + test_rrule(copy∘broadcasted, adjoint, rand(3) .+ im) + + test_rrule(copy∘broadcasted, real, rand(3)) + test_rrule(copy∘broadcasted, real, rand(3) .+ im) + + 
test_rrule(copy∘broadcasted, imag, rand(3)) + test_rrule(copy∘broadcasted, imag, rand(3) .+ im .* rand.()) + + test_rrule(copy∘broadcasted, complex, rand(3)) + end + end + + @testset "scalar rules" begin + @testset "generic" begin + test_rrule(copy∘broadcasted, BS0, sin, rand()) + test_rrule(copy∘broadcasted, BS0, atan, rand(), rand()) + # test_rrule(copy∘broadcasted, BS0, >, rand(), rand()) # DimensionMismatch from FiniteDifferences + end + # Functions with lazy broadcasting rules: + @testset "arithmetic" begin + test_rrule(copy∘broadcasted, +, rand(), rand(), rand()) + test_rrule(copy∘broadcasted, +, rand()) + test_rrule(copy∘broadcasted, -, rand(), rand()) + test_rrule(copy∘broadcasted, -, rand()) + test_rrule(copy∘broadcasted, *, rand(), rand()) + test_rrule(copy∘broadcasted, *, rand(), rand(), rand(), rand()) + test_rrule(copy∘broadcasted, Base.literal_pow, ^, rand(), Val(2)) + test_rrule(copy∘broadcasted, /, rand(), rand()) + end + @testset "identity etc" begin + test_rrule(copy∘broadcasted, identity, rand()) + test_rrule(copy∘broadcasted, Float32, rand(), rtol=1e-4) + test_rrule(copy∘broadcasted, float, rand()) + end + @testset "complex" begin + test_rrule(copy∘broadcasted, conj, rand()) + test_rrule(copy∘broadcasted, conj, rand() + im) + test_rrule(copy∘broadcasted, real, rand()) + test_rrule(copy∘broadcasted, real, rand() + im) + test_rrule(copy∘broadcasted, imag, rand()) + test_rrule(copy∘broadcasted, imag, rand() + im) + test_rrule(copy∘broadcasted, complex, rand()) + end + end +end \ No newline at end of file diff --git a/test/rulesets/Base/mapreduce.jl b/test/rulesets/Base/mapreduce.jl index 4fed0988f..89f41c933 100644 --- a/test/rulesets/Base/mapreduce.jl +++ b/test/rulesets/Base/mapreduce.jl @@ -2,8 +2,6 @@ Base.sum(xs::AbstractArray, weights::AbstractArray) = dot(xs, weights) struct SumRuleConfig <: RuleConfig{Union{HasReverseMode}} end -const CFG = ChainRulesTestUtils.ADviaRuleConfig() - @testset "Reductions" begin @testset "sum(::Tuple)" begin 
test_frule(sum, Tuple(rand(5))) @@ -85,6 +83,7 @@ const CFG = ChainRulesTestUtils.ADviaRuleConfig() # inference fails for array of arrays test_rrule(sum, sum, [[2.0, 4.0], [4.0,1.9]]; check_inferred=false) + test_rrule(sum, norm, collect.(eachcol(rand(3,4))); check_inferred=false) # dims kwarg test_rrule(sum, abs, [-2.0 4.0; 5.0 1.9]; fkwargs=(;dims=1)) diff --git a/test/runtests.jl b/test/runtests.jl index 24c1d85b9..9ac5c5981 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -43,7 +43,7 @@ else end @testset "ChainRules" begin # One overall @testset ensures it keeps going after failures - include("test_helpers.jl") + include("test_helpers.jl") # This can't be skipped println() test_method_tables() # Check the global method tables are consistent @@ -57,6 +57,9 @@ end include_test("rulesets/Base/indexing.jl") include_test("rulesets/Base/mapreduce.jl") include_test("rulesets/Base/sort.jl") + include_test("rulesets/Base/broadcast.jl") + + include_test("unzipped.jl") # used primarily for broadcast println() diff --git a/test/test_helpers.jl b/test/test_helpers.jl index b347e789b..e06759054 100644 --- a/test/test_helpers.jl +++ b/test/test_helpers.jl @@ -98,6 +98,7 @@ function _gpu_test(::typeof(frule), f::Function, g::Function, xs...; kw...) # s _gpu_test(frule, xdots, f, g, xs...; kw...) 
end +const CFG = ChainRulesTestUtils.TestConfig() """ Multiplier(x) @@ -176,6 +177,14 @@ function ChainRulesCore.rrule(::typeof(make_two_vec), x) return make_two_vec(x), make_two_vec_pullback end +"A version of `*` with only an `frule` defined" +fstar(A, B) = A * B +ChainRulesCore.frule((_, ΔA, ΔB), ::typeof(fstar), A, B) = A * B, muladd(ΔA, B, A * ΔB) + +"A version of `log` with only an `frule` defined" +flog(x::Number) = log(x) +ChainRulesCore.frule((_, Δx), ::typeof(flog), x::Number) = log(x), inv(x) * Δx + @testset "test_helpers.jl" begin @testset "Multiplier" begin @@ -204,5 +213,11 @@ end @testset "make_two_vec" begin test_rrule(make_two_vec, 1.5) end + + @testset "fstar, flog" begin + test_frule(fstar, 1.2, 3.4 + 5im) + test_frule(flog, 6.7) + test_frule(flog, 8.9 + im) + end end diff --git a/test/unzipped.jl b/test/unzipped.jl new file mode 100644 index 000000000..97aaa23f5 --- /dev/null +++ b/test/unzipped.jl @@ -0,0 +1,97 @@ + +using ChainRules: unzip_broadcast, unzip #, unzip_map + +@testset "unzipped.jl" begin + @testset "basics: $(sprint(show, fun))" for fun in [unzip_broadcast, unzip∘map, unzip∘broadcast] # unzip_map, + @test_throws Exception fun(sqrt, 1:3) + + @test fun(tuple, 1:3, 4:6) == ([1, 2, 3], [4, 5, 6]) + @test fun(tuple, [1, 10, 100]) == ([1, 10, 100],) + @test fun(tuple, 1:3, fill(nothing, 3)) == (1:3, fill(nothing, 3)) + @test fun(tuple, [1, 10, 100], fill(nothing, 3)) == ([1, 10, 100], fill(nothing, 3)) + @test fun(tuple, fill(nothing, 3), fill(nothing, 3)) == (fill(nothing, 3), fill(nothing, 3)) + + if contains(string(fun), "map") + @test fun(tuple, 1:3, 4:999) == ([1, 2, 3], [4, 5, 6]) + else + @test fun(tuple, [1,2,3], [4 5]) == ([1 1; 2 2; 3 3], [4 5; 4 5; 4 5]) + end + + if contains(string(fun), "map") + @test fun(tuple, (1,2,3), (4,5,6)) == ((1, 2, 3), (4, 5, 6)) + else + @test fun(tuple, (1,2,3), (4,5,6)) == ((1, 2, 3), (4, 5, 6)) + @test fun(tuple, (1,2,3), (7,)) == ((1, 2, 3), (7, 7, 7)) + @test fun(tuple, (1,2,3), 8) == ((1, 2, 
3), (8, 8, 8)) + end + @test fun(tuple, (1,2,3), [4,5,6]) == ([1, 2, 3], [4, 5, 6]) # mix tuple & vector + end + + @testset "rrules" begin + # These exist to allow for second derivatives + + # test_rrule(collect∘unzip_broadcast, tuple, [1,2,3.], [4,5,6.], check_inferred=false) # return type Tuple{NoTangent, NoTangent, Vector{Float64}, Vector{Float64}} does not match inferred return type NTuple{4, Any} + + y1, bk1 = rrule(CFG, unzip_broadcast, tuple, [1,2,3.0], [4,5,6.0]) + @test y1 == ([1, 2, 3], [4, 5, 6]) + @test bk1(([1,10,100.0], [7,8,9.0]))[3] ≈ [1,10,100] + + # bk1(([1,10,100.0], NoTangent())) # DimensionMismatch in FiniteDifferences + + y2, bk2 = rrule(CFG, unzip_broadcast, tuple, [1,2,3.0], [4 5.0], 6.0) + @test y2 == ([1 1; 2 2; 3 3], [4 5; 4 5; 4 5], [6 6; 6 6; 6 6]) + @test bk2(y2)[5] ≈ 36 + + # y4, bk4 = rrule(CFG, unzip_map, tuple, [1,2,3.0], [4,5,6.0]) + # @test y4 == ([1, 2, 3], [4, 5, 6]) + # @test bk4(([1,10,100.0], [7,8,9.0]))[3] ≈ [1,10,100] + end + + @testset "unzip" begin + @test unzip([(1,2), (3,4), (5,6)]) == ([1, 3, 5], [2, 4, 6]) + @test unzip(Any[(1,2), (3,4), (5,6)]) == ([1, 3, 5], [2, 4, 6]) + + @test unzip([(nothing,2), (3,4), (5,6)]) == ([nothing, 3, 5], [2, 4, 6]) + @test unzip([(missing,2), (missing,4), (missing,6)])[2] isa Base.ReinterpretArray + + @test unzip([(1,), (3,), (5,)]) == ([1, 3, 5],) + @test unzip([(1,), (3,), (5,)])[1] isa Base.ReinterpretArray + + @test unzip(((1,2), (3,4), (5,6))) == ((1, 3, 5), (2, 4, 6)) + + # test_rrule(unzip, [(1,2), (3,4), (5.0,6.0)], check_inferred=false) # DimensionMismatch: second dimension of A, 6, does not match length of x, 2 + + y, bk = rrule(unzip, [(1,2), (3,4), (5,6)]) + @test y == ([1, 3, 5], [2, 4, 6]) + @test bk(Tangent{Tuple}([1,1,1], [10,100,1000]))[2] isa Vector{<:Tangent{<:Tuple}} + + y3, bk3 = rrule(unzip, [(1,ZeroTangent()), (3,ZeroTangent()), (5,ZeroTangent())]) + @test y3 == ([1, 3, 5], [ZeroTangent(), ZeroTangent(), ZeroTangent()]) + dx3 = bk3(Tangent{Tuple}([1,1,1], 
[10,100,1000]))[2] + @test dx3 isa Vector{<:Tangent{<:Tuple}} + @test Tuple(dx3[1]) == (1.0, NoTangent()) + + y5, bk5 = rrule(unzip, ((1,2), (3,4), (5,6))) + @test y5 == ((1, 3, 5), (2, 4, 6)) + @test bk5(y5)[2] isa Tangent{<:Tuple} + @test Tuple(bk5(y5)[2][2]) == (3, 4) + dx5 = bk5(((1,10,100), ZeroTangent())) + @test dx5[2] isa Tangent{<:Tuple} + @test Tuple(dx5[2][2]) == (10, ZeroTangent()) + end + + @testset "JLArray tests" begin # fake GPU testing + (y1, y2), bk = rrule(CFG, unzip_broadcast, tuple, [1,2,3.0], [4 5.0]) + (y1jl, y2jl), bk_jl = rrule(CFG, unzip_broadcast, tuple, jl([1,2,3.0]), jl([4 5.0])) + @test y1 == Array(y1jl) + # TODO invent some tests of this rrule's pullback function + + @test unzip(jl([(1,2), (3,4), (5,6)])) == (jl([1, 3, 5]), jl([2, 4, 6])) + + @test unzip(jl([(missing,2), (missing,4), (missing,6)]))[2] == jl([2, 4, 6]) + @test unzip(jl([(missing,2), (missing,4), (missing,6)]))[2] isa Base.ReinterpretArray + + @test unzip(jl([(1,), (3,), (5,)]))[1] == jl([1, 3, 5]) + @test unzip(jl([(1,), (3,), (5,)]))[1] isa Base.ReinterpretArray + end +end \ No newline at end of file