diff --git a/Project.toml b/Project.toml index e44739e88..64b86cc45 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.43.2" +version = "1.44.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl index 7b6568e09..db25537b2 100644 --- a/src/rulesets/Base/indexing.jl +++ b/src/rulesets/Base/indexing.jl @@ -196,9 +196,14 @@ end # This is called by e.g. `iterate(1:0.1:2)`, # and fixes https://github.com/FluxML/Zygote.jl/issues/1247 +# Only needs to accept AbstractRange, but AbstractVector makes testing easier. -function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(Base.unsafe_getindex), x::AbstractRange, i::Integer) - return rrule_via_ad(cfg, getindex, x, i) +function frule((_, ẋ), ::typeof(Base.unsafe_getindex), x::AbstractVector, i::Integer) + return Base.unsafe_getindex(x, i), getindex(ẋ, i) +end + +function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(Base.unsafe_getindex), x::AbstractVector, i::Integer) + return rrule_via_ad(cfg, getindex, x, i) end ##### diff --git a/test/rulesets/Base/indexing.jl b/test/rulesets/Base/indexing.jl index 695f06010..0d5b09398 100644 --- a/test/rulesets/Base/indexing.jl +++ b/test/rulesets/Base/indexing.jl @@ -177,6 +177,11 @@ end test_frule(setindex!, rand(3, 4), [1,10,100.0], :, 3) end +@testset "unsafe_getindex" begin + test_frule(Base.unsafe_getindex, collect(1:0.1:2), 3) + test_rrule(Base.unsafe_getindex, collect(1:0.1:2), 3) +end + @testset "eachslice" begin # Testing eachrow not collect∘eachrow leads to errors, e.g. # test_rrule: eachrow on Vector{Float64}: Error During Test at /Users/me/.julia/packages/ChainRulesTestUtils/8dFTY/src/testers.jl:195