From 1696cee621148a86cfeaa5a8b68595547311bdc0 Mon Sep 17 00:00:00 2001 From: Frames White Date: Wed, 1 Jan 2025 16:14:22 +0800 Subject: [PATCH] Prevent type=inferability escaping for rrule of sortslices --- src/rulesets/Base/sort.jl | 7 +++++-- test/rulesets/Base/sort.jl | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/rulesets/Base/sort.jl b/src/rulesets/Base/sort.jl index 0805da91f..bbd45bbc6 100644 --- a/src/rulesets/Base/sort.jl +++ b/src/rulesets/Base/sort.jl @@ -60,10 +60,13 @@ end function rrule(::typeof(sortslices), x::AbstractArray; dims::Integer, kw...) p = sortperm(collect(eachslice(x; dims=dims)); kw...) - inds = ntuple(d -> d == dims ? p : (:), ndims(x)) function sortslices_pullback(dy) - return (NoTangent(), ∇getindex(x, unthunk(dy), inds...)) + # avoid closing over `inds` as it doesn't fully infer and that makes it worse + # recomputing is cheap + inds_inner = ntuple(d -> d == dims ? p : (:), ndims(x)) + return (NoTangent(), ∇getindex(x, unthunk(dy), inds_inner...)) end + inds = ntuple(d -> d == dims ? p : (:), ndims(x)) return x[inds...], sortslices_pullback end diff --git a/test/rulesets/Base/sort.jl b/test/rulesets/Base/sort.jl index 052045d1e..00b08efbb 100644 --- a/test/rulesets/Base/sort.jl +++ b/test/rulesets/Base/sort.jl @@ -26,7 +26,7 @@ test_rrule(sortslices, rand(3,4); fkwargs=(; dims=2)) test_rrule(sortslices, rand(5,4); fkwargs=(; dims=1, rev=true, by=last)) - test_rrule(sortslices, rand(3,4,5); fkwargs=(; dims=3, by=sum), check_inferred=false) + test_rrule(sortslices, rand(3,4,5); fkwargs=(; dims=3, by=sum)) @test_throws Exception sortslices(Diagonal(1:3), dims=1) end