diff --git a/src/rulesets/Base/broadcast.jl b/src/rulesets/Base/broadcast.jl index ee12767c6..025e9e936 100644 --- a/src/rulesets/Base/broadcast.jl +++ b/src/rulesets/Base/broadcast.jl @@ -393,13 +393,13 @@ end function rrule(cfg::RCR, ::typeof(copy∘broadcasted), f_args...) tmp = rrule(cfg, broadcasted, f_args...) - isnothing(tmp) && throw("rrule gave nothing") + isnothing(tmp) && return nothing y, back = tmp return _maybe_copy(y), back end function rrule(::typeof(copy∘broadcasted), f_args...) tmp = rrule(broadcasted, f_args...) - isnothing(tmp) && throw("rrule gave nothing") + isnothing(tmp) && return nothing y, back = tmp return _maybe_copy(y), back end diff --git a/test/unzipped.jl b/test/unzipped.jl index 0d616b3f2..ae1ea7a14 100644 --- a/test/unzipped.jl +++ b/test/unzipped.jl @@ -1,7 +1,7 @@ using ChainRules: unzip_broadcast, unzip #, unzip_map -@testset "unzip_broadcast.jl" begin +@testset "unzipped.jl" begin @testset "basics: $(sprint(show, fun))" for fun in [unzip_broadcast, unzip∘map, unzip∘broadcast] # unzip_map, @test_throws Exception fun(sqrt, 1:3) @@ -16,10 +16,8 @@ using ChainRules: unzip_broadcast, unzip #, unzip_map else @test fun(tuple, [1,2,3], [4 5]) == ([1 1; 2 2; 3 3], [4 5; 4 5; 4 5]) end - - if fun == unzip_map - @test_broken fun(tuple, (1,2,3), (4,5,6)) == ((1, 2, 3), (4, 5, 6)) - elseif fun == unzip∘map + + if fun == unzip∘map @test fun(tuple, (1,2,3), (4,5,6)) == ((1, 2, 3), (4, 5, 6)) else @test fun(tuple, (1,2,3), (4,5,6)) == ((1, 2, 3), (4, 5, 6)) @@ -44,9 +42,9 @@ using ChainRules: unzip_broadcast, unzip #, unzip_map @test y2 == ([1 1; 2 2; 3 3], [4 5; 4 5; 4 5], [6 6; 6 6; 6 6]) @test bk2(y2)[5] ≈ 36 - y4, bk4 = rrule(CFG, unzip_map, tuple, [1,2,3.0], [4,5,6.0]) - @test y4 == ([1, 2, 3], [4, 5, 6]) - @test bk4(([1,10,100.0], [7,8,9.0]))[3] ≈ [1,10,100] + # y4, bk4 = rrule(CFG, unzip_map, tuple, [1,2,3.0], [4,5,6.0]) + # @test y4 == ([1, 2, 3], [4, 5, 6]) + # @test bk4(([1,10,100.0], [7,8,9.0]))[3] ≈ [1,10,100] end @testset "unzip" begin