From feb845cdbda4f870e29f9d7190680b0458107fc0 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Tue, 22 Oct 2024 14:36:41 +0200 Subject: [PATCH] Preserve ranges. --- Project.toml | 2 +- src/base.jl | 20 ++++++++++++++++++++ test/runtests.jl | 12 ++++++++++++ 3 files changed, 33 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index c8ceaf1..3f56ceb 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Adapt" uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -version = "4.0.4" +version = "4.1.0" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/base.jl b/src/base.jl index 541c834..11ae26f 100644 --- a/src/base.jl +++ b/src/base.jl @@ -16,6 +16,7 @@ _adapt_tuple_structure(to, xs::Tuple) = _adapt_tuple_structure(to, xs::Tuple{}) = () _adapt_tuple_structure(to, xs::Tuple{<:Any}) = (adapt(to, first(xs)), ) + ## Closures # two things can be captured: static parameters, and actual values (fields) @@ -63,3 +64,22 @@ adapt_structure(to, bc::Broadcasted{Style}) where Style = adapt_structure(to, ex::Extruded) = Extruded(adapt(to, ex.x), ex.keeps, ex.defaults) + + +## Ranges + +adapt_structure(to, r::UnitRange) = + UnitRange(adapt(to, r.start), adapt(to, r.stop)) + +adapt_structure(to, r::Base.OneTo) = Base.OneTo(adapt(to, r.stop)) + +adapt_structure(to, r::StepRange) = + StepRange(adapt(to, r.start), adapt(to, r.step), adapt(to, r.stop)) + +adapt_structure(to, r::StepRangeLen) = + StepRangeLen(adapt(to, r.ref), adapt(to, r.step), r.len, r.offset) + +adapt_structure(to, r::Base.Slice) = Base.Slice(adapt(to, r.indices)) + +adapt_structure(to, r::LinRange) = + LinRange(adapt(to, r.start), adapt(to, r.stop), r.len) diff --git a/test/runtests.jl b/test/runtests.jl index 55e913c..27d9031 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -232,3 +232,15 @@ end #@test_adapt SArray [1,2,3] SArray{Tuple{3}}([1,2,3]) @test adapt(SArray, [1,2,3]) === SArray{Tuple{3}}([1,2,3]) end + +@testset "Ranges" begin + # normally these fall back to `convert(Array, r)`, so we only need to test + # that the type matches + + @test adapt(Array, 1:10) === 1:10 + @test adapt(Array, Base.OneTo(10)) === Base.OneTo(10) + @test adapt(Array, 1:2:10) === 1:2:10 + @test adapt(Array, 1.:2.:10.) === 1.:2.:10. + @test adapt(Array, Base.Slice(1:10)) === Base.Slice(1:10) + @test adapt(Array, LinRange(1,2,10)) === LinRange(1,2,10) +end