Skip to content

Commit

Permalink
test bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Aug 7, 2022
1 parent 474fbf2 commit f62a756
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 10 deletions.
4 changes: 2 additions & 2 deletions src/rulesets/Base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -393,13 +393,13 @@ end

function rrule(cfg::RCR, ::typeof(copybroadcasted), f_args...)
tmp = rrule(cfg, broadcasted, f_args...)
isnothing(tmp) && throw("rrule gave nothing")
isnothing(tmp) && return nothing
y, back = tmp
return _maybe_copy(y), back
end
function rrule(::typeof(copybroadcasted), f_args...)
tmp = rrule(broadcasted, f_args...)
isnothing(tmp) && throw("rrule gave nothing")
isnothing(tmp) && return nothing
y, back = tmp
return _maybe_copy(y), back
end
Expand Down
14 changes: 6 additions & 8 deletions test/unzipped.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@

using ChainRules: unzip_broadcast, unzip #, unzip_map

@testset "unzip_broadcast.jl" begin
@testset "unzipped.jl" begin
@testset "basics: $(sprint(show, fun))" for fun in [unzip_broadcast, unzipmap, unzipbroadcast] # unzip_map,
@test_throws Exception fun(sqrt, 1:3)

Expand All @@ -16,10 +16,8 @@ using ChainRules: unzip_broadcast, unzip #, unzip_map
else
@test fun(tuple, [1,2,3], [4 5]) == ([1 1; 2 2; 3 3], [4 5; 4 5; 4 5])
end

if fun == unzip_map
@test_broken fun(tuple, (1,2,3), (4,5,6)) == ((1, 2, 3), (4, 5, 6))
elseif fun == unzipmap

if fun == unzipmap
@test fun(tuple, (1,2,3), (4,5,6)) == ((1, 2, 3), (4, 5, 6))
else
@test fun(tuple, (1,2,3), (4,5,6)) == ((1, 2, 3), (4, 5, 6))
Expand All @@ -44,9 +42,9 @@ using ChainRules: unzip_broadcast, unzip #, unzip_map
@test y2 == ([1 1; 2 2; 3 3], [4 5; 4 5; 4 5], [6 6; 6 6; 6 6])
@test bk2(y2)[5] 36

y4, bk4 = rrule(CFG, unzip_map, tuple, [1,2,3.0], [4,5,6.0])
@test y4 == ([1, 2, 3], [4, 5, 6])
@test bk4(([1,10,100.0], [7,8,9.0]))[3] [1,10,100]
# y4, bk4 = rrule(CFG, unzip_map, tuple, [1,2,3.0], [4,5,6.0])
# @test y4 == ([1, 2, 3], [4, 5, 6])
# @test bk4(([1,10,100.0], [7,8,9.0]))[3] ≈ [1,10,100]
end

@testset "unzip" begin
Expand Down

0 comments on commit f62a756

Please sign in to comment.