From bcfaf79966e91af32eefa8c5663945cfd87a1141 Mon Sep 17 00:00:00 2001 From: Nicholas Bauer Date: Sun, 15 Sep 2024 20:02:09 -0400 Subject: [PATCH 01/24] =?UTF-8?q?Fix=20`=E2=88=87eachslice`=20output=20arr?= =?UTF-8?q?ay=20type?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/rulesets/Base/indexing.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl index 830571ecd..8c5865a98 100644 --- a/src/rulesets/Base/indexing.jl +++ b/src/rulesets/Base/indexing.jl @@ -267,7 +267,7 @@ function ∇eachslice(dys_raw, x::AbstractArray, vd::Val{dim}) where {dim} if i1 === nothing # all slices are Zero! return _zero_fill!(similar(x, float(eltype(x)), axes(x))) end - T = promote_type(eltype(dys[i1]), eltype(x)) + T = promote_type(eltype.(dys)...) # The whole point of this gradient is that we can allocate one `dx` array: dx = similar(x, T, axes(x)) for i in axes(x, dim) @@ -282,8 +282,7 @@ function ∇eachslice(dys_raw, x::AbstractArray, vd::Val{dim}) where {dim} end ∇eachslice(dys::AbstractZero, x::AbstractArray, vd::Val{dim}) where {dim} = dys -_zero_fill!(dx::AbstractArray{<:Number}) = fill!(dx, zero(eltype(dx))) -_zero_fill!(dx::AbstractArray) = map!(zero, dx, dx) +_zero_fill!(dx::AbstractArray) = fill!(dx, zero(eltype(dx))) function rrule(::typeof(∇eachslice), dys, x, vd::Val) function ∇∇eachslice(dz_raw) From 061685ae94d281eed379b152b1e00f679e2fac87 Mon Sep 17 00:00:00 2001 From: Nicholas Bauer Date: Sun, 15 Sep 2024 21:19:34 -0400 Subject: [PATCH 02/24] Add test --- test/rulesets/Base/indexing.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/rulesets/Base/indexing.jl b/test/rulesets/Base/indexing.jl index e878dd061..19da92022 100644 --- a/test/rulesets/Base/indexing.jl +++ b/test/rulesets/Base/indexing.jl @@ -227,6 +227,8 @@ end # On 1.6, ComposedFunction doesn't take keywords. Only affects this testing strategy, not real use. test_rrule(collect∘eachslice, rand(3, 4, 5); fkwargs = (; dims = 3)) test_rrule(collect∘eachslice, rand(3, 4, 5); fkwargs = (; dims = (2,))) + + test_rrule(collect∘eachslice, FooTwoField.(rand(3, 4, 5), rand(3, 4, 5)); check_inferred = false, fkwargs = (; dims = 3)) end # Make sure pulling back an array that mixes some AbstractZeros in works right From ca652dbdf5f656eb5cfb1453ff3b1a52de71d4e6 Mon Sep 17 00:00:00 2001 From: Nicholas Bauer Date: Sun, 15 Sep 2024 21:40:28 -0400 Subject: [PATCH 03/24] format suggestion --- test/rulesets/Base/indexing.jl | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/test/rulesets/Base/indexing.jl b/test/rulesets/Base/indexing.jl index 19da92022..65a2ad47b 100644 --- a/test/rulesets/Base/indexing.jl +++ b/test/rulesets/Base/indexing.jl @@ -228,7 +228,12 @@ end test_rrule(collect∘eachslice, rand(3, 4, 5); fkwargs = (; dims = 3)) test_rrule(collect∘eachslice, rand(3, 4, 5); fkwargs = (; dims = (2,))) - test_rrule(collect∘eachslice, FooTwoField.(rand(3, 4, 5), rand(3, 4, 5)); check_inferred = false, fkwargs = (; dims = 3)) + test_rrule( + collect∘eachslice, + FooTwoField.(rand(3, 4, 5), rand(3, 4, 5)); + check_inferred = false, + fkwargs = (; dims = 3) + ) end # Make sure pulling back an array that mixes some AbstractZeros in works right From 99d43c0fd86eac1a343725b7a17919932cc307a9 Mon Sep 17 00:00:00 2001 From: Nicholas Bauer Date: Sun, 15 Sep 2024 21:42:03 -0400 Subject: [PATCH 04/24] Bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index a679c914d..50f5edc2c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.70.0" +version = "1.71.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" From ff93e2b6f37301e0d02fe0e199ca846776dcc9de Mon Sep 17 00:00:00 2001 From: Nicholas Bauer Date: Sun, 15 Sep 2024 22:28:34 -0400 Subject: [PATCH 05/24] Formatting fix --- test/rulesets/Base/indexing.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/rulesets/Base/indexing.jl b/test/rulesets/Base/indexing.jl index 65a2ad47b..4d4fee93b 100644 --- a/test/rulesets/Base/indexing.jl +++ b/test/rulesets/Base/indexing.jl @@ -231,8 +231,8 @@ end test_rrule( collect∘eachslice, FooTwoField.(rand(3, 4, 5), rand(3, 4, 5)); - check_inferred = false, - fkwargs = (; dims = 3) + check_inferred=false, + fkwargs=(; dims=3) ) end From aa58c9b60a014d2a4c0d6574e5cb492c5de4692f Mon Sep 17 00:00:00 2001 From: Nicholas Bauer Date: Sun, 15 Sep 2024 23:33:13 -0400 Subject: [PATCH 06/24] Add right test --- test/rulesets/Base/indexing.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/rulesets/Base/indexing.jl b/test/rulesets/Base/indexing.jl index 4d4fee93b..5d43448d0 100644 --- a/test/rulesets/Base/indexing.jl +++ b/test/rulesets/Base/indexing.jl @@ -242,6 +242,10 @@ end @test back([1:3, ZeroTangent(), 7:9, NoTangent()])[2] isa Matrix{Float64} @test back([ZeroTangent(), ZeroTangent(), NoTangent(), NoTangent()]) == (NoTangent(), [0 0 0 0; 0 0 0 0; 0 0 0 0]) + _, back = ChainRules.rrule(eachslice, FooTwoField.(rand(2, 3, 2), rand(2, 3, 2)); dims = 3) + @test back([fill(Tangent{Any}(; x = 0.0, y = 1.0), 2, 3), fill(ZeroTangent(), 2, 3)]) == + (NoTangent(), [fill(Tangent{Any}(; x = 0.0, y = 1.0), 2, 3);;; fill(ZeroTangent(), 2, 3)]) + # Second derivative rule test_rrule(ChainRules.∇eachslice, [rand(4) for _ in 1:3], rand(3, 4), Val(1)) test_rrule(ChainRules.∇eachslice, [rand(3) for _ in 1:4], rand(3, 4), Val(2)) From 45f00a5168ac64ce2eb147c0ae853095e071cc63 Mon Sep 17 00:00:00 2001 From: Nicholas Bauer Date: Sun, 15 Sep 2024 23:34:27 -0400 Subject: [PATCH 07/24] Add promotion rules (belong in ChainRulesCore) --- src/rulesets/Base/indexing.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl index 8c5865a98..e2523a8fa 100644 --- a/src/rulesets/Base/indexing.jl +++ b/src/rulesets/Base/indexing.jl @@ -284,6 +284,10 @@ end _zero_fill!(dx::AbstractArray) = fill!(dx, zero(eltype(dx))) +# Belong in ChainRulesCore +Base.promote_type(T::Type{<:Number}, S::Type{<:AbstractZero}) = T +Base.promote_type(T::Type{<:AbstractZero}, S::Type{<:Number}) = S + function rrule(::typeof(∇eachslice), dys, x, vd::Val) function ∇∇eachslice(dz_raw) dz = unthunk(dz_raw) From b5396f76665c57f18849c62c4136f120d6668cdb Mon Sep 17 00:00:00 2001 From: Nicholas Bauer Date: Sun, 15 Sep 2024 23:37:14 -0400 Subject: [PATCH 08/24] format fixes --- test/rulesets/Base/indexing.jl | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/test/rulesets/Base/indexing.jl b/test/rulesets/Base/indexing.jl index 5d43448d0..1f627efd7 100644 --- a/test/rulesets/Base/indexing.jl +++ b/test/rulesets/Base/indexing.jl @@ -229,10 +229,10 @@ end test_rrule(collect∘eachslice, rand(3, 4, 5); fkwargs = (; dims = (2,))) test_rrule( - collect∘eachslice, + collect ∘ eachslice, FooTwoField.(rand(3, 4, 5), rand(3, 4, 5)); check_inferred=false, - fkwargs=(; dims=3) + fkwargs=(; dims=3), ) end @@ -243,8 +243,9 @@ end @test back([ZeroTangent(), ZeroTangent(), NoTangent(), NoTangent()]) == (NoTangent(), [0 0 0 0; 0 0 0 0; 0 0 0 0]) _, back = ChainRules.rrule(eachslice, FooTwoField.(rand(2, 3, 2), rand(2, 3, 2)); dims = 3) - @test back([fill(Tangent{Any}(; x = 0.0, y = 1.0), 2, 3), fill(ZeroTangent(), 2, 3)]) == - (NoTangent(), [fill(Tangent{Any}(; x = 0.0, y = 1.0), 2, 3);;; fill(ZeroTangent(), 2, 3)]) + @test back([fill(Tangent{Any}(; x = 0.0, y = 1.0), 2, 3), fill(ZeroTangent(), 2, 3)]) == ( + NoTangent(), [fill(Tangent{Any}(; x = 0.0, y = 1.0), 2, 3);;; fill(ZeroTangent(), 2, 3)] + ) # Second derivative rule test_rrule(ChainRules.∇eachslice, [rand(4) for _ in 1:3], rand(3, 4), Val(1)) From 940d63f866c037e15f2515d3cdf23cada9fe91ad Mon Sep 17 00:00:00 2001 From: Nicholas Bauer Date: Sun, 15 Sep 2024 23:55:44 -0400 Subject: [PATCH 09/24] promote_type -> promote_rule, add eltype for NoTangent --- src/rulesets/Base/indexing.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl index e2523a8fa..430fd6f5b 100644 --- a/src/rulesets/Base/indexing.jl +++ b/src/rulesets/Base/indexing.jl @@ -285,8 +285,9 @@ end _zero_fill!(dx::AbstractArray) = fill!(dx, zero(eltype(dx))) # Belong in ChainRulesCore -Base.promote_type(T::Type{<:Number}, S::Type{<:AbstractZero}) = T -Base.promote_type(T::Type{<:AbstractZero}, S::Type{<:Number}) = S +Base.promote_rule(T::Type{<:Number}, S::Type{<:AbstractZero}) = T +Base.promote_rule(T::Type{<:AbstractZero}, S::Type{<:Number}) = S +Base.eltype(::Type{NoTangent}) = NoTangent function rrule(::typeof(∇eachslice), dys, x, vd::Val) function ∇∇eachslice(dz_raw) From 41759242c1e2f1e915e91677d83176b69f807aed Mon Sep 17 00:00:00 2001 From: Nicholas Bauer Date: Mon, 16 Sep 2024 00:10:02 -0400 Subject: [PATCH 10/24] format fix --- test/rulesets/Base/indexing.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/rulesets/Base/indexing.jl b/test/rulesets/Base/indexing.jl index 1f627efd7..317b2002c 100644 --- a/test/rulesets/Base/indexing.jl +++ b/test/rulesets/Base/indexing.jl @@ -242,7 +242,9 @@ end @test back([1:3, ZeroTangent(), 7:9, NoTangent()])[2] isa Matrix{Float64} @test back([ZeroTangent(), ZeroTangent(), NoTangent(), NoTangent()]) == (NoTangent(), [0 0 0 0; 0 0 0 0; 0 0 0 0]) - _, back = ChainRules.rrule(eachslice, FooTwoField.(rand(2, 3, 2), rand(2, 3, 2)); dims = 3) + _, back = ChainRules.rrule( + eachslice, FooTwoField.(rand(2, 3, 2), rand(2, 3, 2)); dims = 3 + ) @test back([fill(Tangent{Any}(; x = 0.0, y = 1.0), 2, 3), fill(ZeroTangent(), 2, 3)]) == ( NoTangent(), [fill(Tangent{Any}(; x = 0.0, y = 1.0), 2, 3);;; fill(ZeroTangent(), 2, 3)] ) From 0c0ce2b31edf95ea18d5e0e50a89a8a41efd7e4c Mon Sep 17 00:00:00 2001 From: Nicholas Bauer Date: Mon, 16 Sep 2024 00:13:47 -0400 Subject: [PATCH 11/24] format fix --- test/rulesets/Base/indexing.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/rulesets/Base/indexing.jl b/test/rulesets/Base/indexing.jl index 317b2002c..15dde2085 100644 --- a/test/rulesets/Base/indexing.jl +++ b/test/rulesets/Base/indexing.jl @@ -243,10 +243,10 @@ end @test back([ZeroTangent(), ZeroTangent(), NoTangent(), NoTangent()]) == (NoTangent(), [0 0 0 0; 0 0 0 0; 0 0 0 0]) _, back = ChainRules.rrule( - eachslice, FooTwoField.(rand(2, 3, 2), rand(2, 3, 2)); dims = 3 + eachslice, FooTwoField.(rand(2, 3, 2), rand(2, 3, 2)); dims=3 ) - @test back([fill(Tangent{Any}(; x = 0.0, y = 1.0), 2, 3), fill(ZeroTangent(), 2, 3)]) == ( - NoTangent(), [fill(Tangent{Any}(; x = 0.0, y = 1.0), 2, 3);;; fill(ZeroTangent(), 2, 3)] + @test back([fill(Tangent{Any}(; x=0.0, y=1.0), 2, 3), fill(ZeroTangent(), 2, 3)]) == ( + NoTangent(), [fill(Tangent{Any}(; x=0.0, y=1.0), 2, 3);;; fill(ZeroTangent(), 2, 3)] ) # Second derivative rule From 250fec9b74d6997a74a7f44f33d1ed63dc70d7d3 Mon Sep 17 00:00:00 2001 From: Nicholas Bauer Date: Mon, 16 Sep 2024 01:32:52 -0400 Subject: [PATCH 12/24] turn off inference check --- test/rulesets/Base/indexing.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/rulesets/Base/indexing.jl b/test/rulesets/Base/indexing.jl index 15dde2085..5ffd99cae 100644 --- a/test/rulesets/Base/indexing.jl +++ b/test/rulesets/Base/indexing.jl @@ -217,16 +217,16 @@ end # DimensionMismatch("second dimension of A, 6, does not match length of x, 5") # Probably similar to https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/234 (about Broadcasted not Generator) - test_rrule(collect∘eachrow, rand(5)) - test_rrule(collect∘eachrow, rand(3, 4)) + test_rrule(collect∘eachrow, rand(5); check_inferred=false) + test_rrule(collect∘eachrow, rand(3, 4); check_inferred=false) - test_rrule(collect∘eachcol, rand(3, 4)) + test_rrule(collect∘eachcol, rand(3, 4); check_inferred=false) @test_skip test_rrule(collect∘eachcol, Diagonal(rand(5))) # works locally! if VERSION >= v"1.7" # On 1.6, ComposedFunction doesn't take keywords. Only affects this testing strategy, not real use. - test_rrule(collect∘eachslice, rand(3, 4, 5); fkwargs = (; dims = 3)) - test_rrule(collect∘eachslice, rand(3, 4, 5); fkwargs = (; dims = (2,))) + test_rrule(collect∘eachslice, rand(3, 4, 5); check_inferred=false, fkwargs = (; dims = 3)) + test_rrule(collect∘eachslice, rand(3, 4, 5); check_inferred=false, fkwargs = (; dims = (2,))) test_rrule( collect ∘ eachslice, From 15a50508a0b1107943f3fe0608d7f9129a3904e5 Mon Sep 17 00:00:00 2001 From: Nicholas Bauer Date: Mon, 16 Sep 2024 01:39:01 -0400 Subject: [PATCH 13/24] format fix --- test/rulesets/Base/indexing.jl | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/test/rulesets/Base/indexing.jl b/test/rulesets/Base/indexing.jl index 5ffd99cae..70d015f9b 100644 --- a/test/rulesets/Base/indexing.jl +++ b/test/rulesets/Base/indexing.jl @@ -217,16 +217,20 @@ end # DimensionMismatch("second dimension of A, 6, does not match length of x, 5") # Probably similar to https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/234 (about Broadcasted not Generator) - test_rrule(collect∘eachrow, rand(5); check_inferred=false) - test_rrule(collect∘eachrow, rand(3, 4); check_inferred=false) + test_rrule(collect ∘ eachrow, rand(5); check_inferred=false) + test_rrule(collect ∘ eachrow, rand(3, 4); check_inferred=false) - test_rrule(collect∘eachcol, rand(3, 4); check_inferred=false) - @test_skip test_rrule(collect∘eachcol, Diagonal(rand(5))) # works locally! + test_rrule(collect ∘ eachcol, rand(3, 4); check_inferred=false) + @test_skip test_rrule(collect ∘ eachcol, Diagonal(rand(5))) # works locally! if VERSION >= v"1.7" # On 1.6, ComposedFunction doesn't take keywords. Only affects this testing strategy, not real use. - test_rrule(collect∘eachslice, rand(3, 4, 5); check_inferred=false, fkwargs = (; dims = 3)) - test_rrule(collect∘eachslice, rand(3, 4, 5); check_inferred=false, fkwargs = (; dims = (2,))) + test_rrule( + collect ∘ eachslice, rand(3, 4, 5); check_inferred=false, fkwargs = (; dims = 3) + ) + test_rrule( + collect ∘ eachslice, rand(3, 4, 5); check_inferred=false, fkwargs = (; dims = (2,)) + ) test_rrule( collect ∘ eachslice, From 26b61c5fee32dbc63ce8a41d3067d941ecee08a7 Mon Sep 17 00:00:00 2001 From: Nicholas Bauer Date: Mon, 16 Sep 2024 08:32:32 -0400 Subject: [PATCH 14/24] formatter --- test/rulesets/Base/indexing.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/rulesets/Base/indexing.jl b/test/rulesets/Base/indexing.jl index 70d015f9b..3a23516e5 100644 --- a/test/rulesets/Base/indexing.jl +++ b/test/rulesets/Base/indexing.jl @@ -226,10 +226,10 @@ end if VERSION >= v"1.7" # On 1.6, ComposedFunction doesn't take keywords. Only affects this testing strategy, not real use. test_rrule( - collect ∘ eachslice, rand(3, 4, 5); check_inferred=false, fkwargs = (; dims = 3) + collect ∘ eachslice, rand(3, 4, 5); check_inferred=false, fkwargs=(; dims=3) ) test_rrule( - collect ∘ eachslice, rand(3, 4, 5); check_inferred=false, fkwargs = (; dims = (2,)) + collect ∘ eachslice, rand(3, 4, 5); check_inferred=false, fkwargs=(; dims=(2,)) ) test_rrule( From ccf6570ccca2800e5c9000fc354a0611da201d2a Mon Sep 17 00:00:00 2001 From: Nicholas Bauer Date: Mon, 16 Sep 2024 09:55:54 -0400 Subject: [PATCH 15/24] Fix promotion inference --- src/rulesets/Base/indexing.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl index 430fd6f5b..cae1272a9 100644 --- a/src/rulesets/Base/indexing.jl +++ b/src/rulesets/Base/indexing.jl @@ -267,7 +267,7 @@ function ∇eachslice(dys_raw, x::AbstractArray, vd::Val{dim}) where {dim} if i1 === nothing # all slices are Zero! return _zero_fill!(similar(x, float(eltype(x)), axes(x))) end - T = promote_type(eltype.(dys)...) + T = Base.promote_eltype(dys...) # The whole point of this gradient is that we can allocate one `dx` array: dx = similar(x, T, axes(x)) for i in axes(x, dim) From e72cb1ca3c2c0a1dc4ca2b5d56ddaba03f426d27 Mon Sep 17 00:00:00 2001 From: Nicholas Bauer Date: Mon, 16 Sep 2024 12:18:53 -0400 Subject: [PATCH 16/24] re-add inference testing --- test/rulesets/Base/indexing.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/rulesets/Base/indexing.jl b/test/rulesets/Base/indexing.jl index 3a23516e5..7d5a0dcbf 100644 --- a/test/rulesets/Base/indexing.jl +++ b/test/rulesets/Base/indexing.jl @@ -217,19 +217,19 @@ end # DimensionMismatch("second dimension of A, 6, does not match length of x, 5") # Probably similar to https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/234 (about Broadcasted not Generator) - test_rrule(collect ∘ eachrow, rand(5); check_inferred=false) - test_rrule(collect ∘ eachrow, rand(3, 4); check_inferred=false) + test_rrule(collect ∘ eachrow, rand(5)) + test_rrule(collect ∘ eachrow, rand(3, 4)) - test_rrule(collect ∘ eachcol, rand(3, 4); check_inferred=false) + test_rrule(collect ∘ eachcol, rand(3, 4)) @test_skip test_rrule(collect ∘ eachcol, Diagonal(rand(5))) # works locally! if VERSION >= v"1.7" # On 1.6, ComposedFunction doesn't take keywords. Only affects this testing strategy, not real use. test_rrule( - collect ∘ eachslice, rand(3, 4, 5); check_inferred=false, fkwargs=(; dims=3) + collect ∘ eachslice, rand(3, 4, 5); fkwargs=(; dims=3) ) test_rrule( - collect ∘ eachslice, rand(3, 4, 5); check_inferred=false, fkwargs=(; dims=(2,)) + collect ∘ eachslice, rand(3, 4, 5); fkwargs=(; dims=(2,)) ) test_rrule( From 948f7aaad98c968901048b31296e4b3adf2f2db5 Mon Sep 17 00:00:00 2001 From: Nicholas Bauer Date: Mon, 16 Sep 2024 12:19:11 -0400 Subject: [PATCH 17/24] Make test work on earlier Julia versions --- test/rulesets/Base/indexing.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/rulesets/Base/indexing.jl b/test/rulesets/Base/indexing.jl index 7d5a0dcbf..88ab3bbe6 100644 --- a/test/rulesets/Base/indexing.jl +++ b/test/rulesets/Base/indexing.jl @@ -250,7 +250,7 @@ end eachslice, FooTwoField.(rand(2, 3, 2), rand(2, 3, 2)); dims=3 ) @test back([fill(Tangent{Any}(; x=0.0, y=1.0), 2, 3), fill(ZeroTangent(), 2, 3)]) == ( - NoTangent(), [fill(Tangent{Any}(; x=0.0, y=1.0), 2, 3);;; fill(ZeroTangent(), 2, 3)] + NoTangent(), cat(fill(Tangent{Any}(; x=0.0, y=1.0), 2, 3), fill(ZeroTangent(), 2, 3), dims = 3) ) # Second derivative rule From b65fde204ee74001649047b369c43b29a18ecb95 Mon Sep 17 00:00:00 2001 From: Nicholas Bauer Date: Mon, 16 Sep 2024 12:22:04 -0400 Subject: [PATCH 18/24] format --- test/rulesets/Base/indexing.jl | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/test/rulesets/Base/indexing.jl b/test/rulesets/Base/indexing.jl index 88ab3bbe6..7500b36d6 100644 --- a/test/rulesets/Base/indexing.jl +++ b/test/rulesets/Base/indexing.jl @@ -225,12 +225,8 @@ end if VERSION >= v"1.7" # On 1.6, ComposedFunction doesn't take keywords. Only affects this testing strategy, not real use. - test_rrule( - collect ∘ eachslice, rand(3, 4, 5); fkwargs=(; dims=3) - ) - test_rrule( - collect ∘ eachslice, rand(3, 4, 5); fkwargs=(; dims=(2,)) - ) + test_rrule(collect ∘ eachslice, rand(3, 4, 5); fkwargs=(; dims=3)) + test_rrule(collect ∘ eachslice, rand(3, 4, 5); fkwargs=(; dims=(2,))) test_rrule( collect ∘ eachslice, @@ -250,7 +246,8 @@ end eachslice, FooTwoField.(rand(2, 3, 2), rand(2, 3, 2)); dims=3 ) @test back([fill(Tangent{Any}(; x=0.0, y=1.0), 2, 3), fill(ZeroTangent(), 2, 3)]) == ( - NoTangent(), cat(fill(Tangent{Any}(; x=0.0, y=1.0), 2, 3), fill(ZeroTangent(), 2, 3), dims = 3) + NoTangent(), + cat(fill(Tangent{Any}(; x=0.0, y=1.0), 2, 3), fill(ZeroTangent(), 2, 3); dims=3) ) # Second derivative rule From 01a9fba86b12a0d5b8556db560d7f6d34095ae22 Mon Sep 17 00:00:00 2001 From: Nicholas Bauer Date: Mon, 16 Sep 2024 12:26:07 -0400 Subject: [PATCH 19/24] format --- test/rulesets/Base/indexing.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/rulesets/Base/indexing.jl b/test/rulesets/Base/indexing.jl index 7500b36d6..3dbf553f1 100644 --- a/test/rulesets/Base/indexing.jl +++ b/test/rulesets/Base/indexing.jl @@ -247,7 +247,7 @@ end ) @test back([fill(Tangent{Any}(; x=0.0, y=1.0), 2, 3), fill(ZeroTangent(), 2, 3)]) == ( NoTangent(), - cat(fill(Tangent{Any}(; x=0.0, y=1.0), 2, 3), fill(ZeroTangent(), 2, 3); dims=3) + cat(fill(Tangent{Any}(; x=0.0, y=1.0), 2, 3), fill(ZeroTangent(), 2, 3); dims=3), ) # Second derivative rule From 1424ccdb422a67a3251acc70af82f5fca79ac9dd Mon Sep 17 00:00:00 2001 From: Nicholas Bauer Date: Mon, 16 Sep 2024 21:52:14 -0400 Subject: [PATCH 20/24] Don't check inference on v1.6 --- test/rulesets/Base/indexing.jl | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/test/rulesets/Base/indexing.jl b/test/rulesets/Base/indexing.jl index 3dbf553f1..c3c08b7e2 100644 --- a/test/rulesets/Base/indexing.jl +++ b/test/rulesets/Base/indexing.jl @@ -217,10 +217,11 @@ end # DimensionMismatch("second dimension of A, 6, does not match length of x, 5") # Probably similar to https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/234 (about Broadcasted not Generator) - test_rrule(collect ∘ eachrow, rand(5)) - test_rrule(collect ∘ eachrow, rand(3, 4)) + # Inference on 1.6 sometimes fails, so don't enforce there. + test_rrule(collect ∘ eachrow, rand(5); check_inferred=(VERSION >= v"1.7")) + test_rrule(collect ∘ eachrow, rand(3, 4); check_inferred=(VERSION >= v"1.7")) - test_rrule(collect ∘ eachcol, rand(3, 4)) + test_rrule(collect ∘ eachcol, rand(3, 4); check_inferred=(VERSION >= v"1.7")) @test_skip test_rrule(collect ∘ eachcol, Diagonal(rand(5))) # works locally! if VERSION >= v"1.7" @@ -253,5 +254,5 @@ end # Second derivative rule test_rrule(ChainRules.∇eachslice, [rand(4) for _ in 1:3], rand(3, 4), Val(1)) test_rrule(ChainRules.∇eachslice, [rand(3) for _ in 1:4], rand(3, 4), Val(2)) - test_rrule(ChainRules.∇eachslice, [rand(2, 3) for _ in 1:4], rand(2, 3, 4), Val(3), check_inferred=false) + test_rrule(ChainRules.∇eachslice, [rand(2, 3) for _ in 1:4], rand(2, 3, 4), Val(3); check_inferred=(VERSION >= v"1.7")) end From 50f3dcf4d9227ff6fc18bd7d95efee4305310756 Mon Sep 17 00:00:00 2001 From: Nicholas Bauer Date: Tue, 17 Sep 2024 13:58:13 -0400 Subject: [PATCH 21/24] Moved to ChainRulesCore and bumped compat to new version --- Project.toml | 2 +- src/rulesets/Base/indexing.jl | 5 ----- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/Project.toml b/Project.toml index 50f5edc2c..5c3a5607e 100644 --- a/Project.toml +++ b/Project.toml @@ -20,7 +20,7 @@ SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" [compat] Adapt = "3.4.0, 4" -ChainRulesCore = "1.20" +ChainRulesCore = "1.25" ChainRulesTestUtils = "1.5" Compat = "3.46, 4.2" Distributed = "1" diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl index cae1272a9..61216bda2 100644 --- a/src/rulesets/Base/indexing.jl +++ b/src/rulesets/Base/indexing.jl @@ -284,11 +284,6 @@ end _zero_fill!(dx::AbstractArray) = fill!(dx, zero(eltype(dx))) -# Belong in ChainRulesCore -Base.promote_rule(T::Type{<:Number}, S::Type{<:AbstractZero}) = T -Base.promote_rule(T::Type{<:AbstractZero}, S::Type{<:Number}) = S -Base.eltype(::Type{NoTangent}) = NoTangent - function rrule(::typeof(∇eachslice), dys, x, vd::Val) function ∇∇eachslice(dz_raw) dz = unthunk(dz_raw) From 71959fa36f43bb8f31d7cbc58f79df24b0227b69 Mon Sep 17 00:00:00 2001 From: Simeon David Schaub Date: Wed, 18 Sep 2024 08:01:53 +0200 Subject: [PATCH 22/24] Update format.yml --- .github/workflows/format.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml index f80377a24..6daf12a92 100644 --- a/.github/workflows/format.yml +++ b/.github/workflows/format.yml @@ -12,6 +12,10 @@ concurrency: jobs: format: runs-on: ubuntu-latest + permissions: + contents: read + checks: write + pull-requests: write steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@latest From 9538e2aaf84278985215b61e3ad5e646aefcf3c3 Mon Sep 17 00:00:00 2001 From: Simeon David Schaub Date: Wed, 18 Sep 2024 08:05:20 +0200 Subject: [PATCH 23/24] Update format.yml --- .github/workflows/format.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml index 6daf12a92..e8334abca 100644 --- a/.github/workflows/format.yml +++ b/.github/workflows/format.yml @@ -26,6 +26,7 @@ jobs: julia -e 'using JuliaFormatter; format("."; verbose=true)' - uses: reviewdog/action-suggester@v1 with: + github_token: ${{ secrets.GITHUB_TOKEN }} tool_name: JuliaFormatter fail_on_error: true filter_mode: added From ba19c99aa88b3263b3adc5264ab889af88cf3f04 Mon Sep 17 00:00:00 2001 From: Nicholas Bauer Date: Wed, 18 Sep 2024 09:41:08 -0400 Subject: [PATCH 24/24] format fix --- test/rulesets/Base/indexing.jl | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/test/rulesets/Base/indexing.jl b/test/rulesets/Base/indexing.jl index c3c08b7e2..f80a37048 100644 --- a/test/rulesets/Base/indexing.jl +++ b/test/rulesets/Base/indexing.jl @@ -254,5 +254,11 @@ end # Second derivative rule test_rrule(ChainRules.∇eachslice, [rand(4) for _ in 1:3], rand(3, 4), Val(1)) test_rrule(ChainRules.∇eachslice, [rand(3) for _ in 1:4], rand(3, 4), Val(2)) - test_rrule(ChainRules.∇eachslice, [rand(2, 3) for _ in 1:4], rand(2, 3, 4), Val(3); check_inferred=(VERSION >= v"1.7")) + test_rrule( + ChainRules.∇eachslice, + [rand(2, 3) for _ in 1:4], + rand(2, 3, 4), + Val(3); + check_inferred=(VERSION >= v"1.7"), + ) end