From 9e1aa8c9a0df72278987dd96b3dade9599ba5581 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 2 Aug 2022 23:09:46 -0400 Subject: [PATCH] test unsafe_getindex --- Project.toml | 2 +- src/rulesets/Base/indexing.jl | 9 +++++++-- test/rulesets/Base/indexing.jl | 5 +++++ 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index e44739e88..64b86cc45 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.43.2" +version = "1.44.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl index 7b6568e09..db25537b2 100644 --- a/src/rulesets/Base/indexing.jl +++ b/src/rulesets/Base/indexing.jl @@ -196,9 +196,14 @@ end # This is called by e.g. `iterate(1:0.1:2)`, # and fixes https://github.com/FluxML/Zygote.jl/issues/1247 +# Only needs to accept AbstractRange, but AbstractVector makes testing easier. -function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(Base.unsafe_getindex), x::AbstractRange, i::Integer) - return rrule_via_ad(cfg, getindex, x, i) +function frule((_, ẋ), ::typeof(Base.unsafe_getindex), x::AbstractVector, i::Integer) + return Base.unsafe_getindex(x, i), getindex(ẋ, i) +end + +function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(Base.unsafe_getindex), x::AbstractVector, i::Integer) + return rrule_via_ad(cfg, getindex, x, i) end ##### diff --git a/test/rulesets/Base/indexing.jl b/test/rulesets/Base/indexing.jl index 695f06010..0d5b09398 100644 --- a/test/rulesets/Base/indexing.jl +++ b/test/rulesets/Base/indexing.jl @@ -177,6 +177,11 @@ end test_frule(setindex!, rand(3, 4), [1,10,100.0], :, 3) end +@testset "unsafe_getindex" begin + test_frule(Base.unsafe_getindex, collect(1:0.1:2), 3) + test_rrule(Base.unsafe_getindex, collect(1:0.1:2), 3) +end + @testset "eachslice" begin # Testing eachrow not collect∘eachrow leads to errors, e.g. # test_rrule: eachrow on Vector{Float64}: Error During Test at /Users/me/.julia/packages/ChainRulesTestUtils/8dFTY/src/testers.jl:195