Skip to content

Commit

Permalink
test unsafe_getindex
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Aug 10, 2022
1 parent c1ebf3f commit 9e1aa8c
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 3 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "1.43.2"
version = "1.44.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
9 changes: 7 additions & 2 deletions src/rulesets/Base/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -196,9 +196,14 @@ end

# This is called by e.g. `iterate(1:0.1:2)`,
# and fixes https://github.com/FluxML/Zygote.jl/issues/1247
# Only needs to accept AbstractRange, but AbstractVector makes testing easier.

function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(Base.unsafe_getindex), x::AbstractRange, i::Integer)
return rrule_via_ad(cfg, getindex, x, i)
function frule((_, ẋ), ::typeof(Base.unsafe_getindex), x::AbstractVector, i::Integer)
return Base.unsafe_getindex(x, i), getindex(ẋ, i)
end

function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(Base.unsafe_getindex), x::AbstractVector, i::Integer)
return rrule_via_ad(cfg, getindex, x, i)
end

#####
Expand Down
5 changes: 5 additions & 0 deletions test/rulesets/Base/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,11 @@ end
test_frule(setindex!, rand(3, 4), [1,10,100.0], :, 3)
end

@testset "unsafe_getindex" begin
test_frule(Base.unsafe_getindex, collect(1:0.1:2), 3)
test_rrule(Base.unsafe_getindex, collect(1:0.1:2), 3)
end

@testset "eachslice" begin
# Testing eachrow not collect∘eachrow leads to errors, e.g.
# test_rrule: eachrow on Vector{Float64}: Error During Test at /Users/me/.julia/packages/ChainRulesTestUtils/8dFTY/src/testers.jl:195
Expand Down

0 comments on commit 9e1aa8c

Please sign in to comment.