diff --git a/Project.toml b/Project.toml index e44739e88..b66176754 100644 --- a/Project.toml +++ b/Project.toml @@ -1,8 +1,9 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.43.2" +version = "1.44.0" [deps] +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" @@ -16,6 +17,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" [compat] +Adapt = "3.4.0" ChainRulesCore = "1.15.3" ChainRulesTestUtils = "1.5" Compat = "3.42.0, 4" @@ -30,7 +32,6 @@ StructArrays = "0.6.11" julia = "1.6" [extras] -Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" @@ -40,4 +41,4 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Adapt", "ChainRulesTestUtils", "FiniteDifferences", "JLArrays", "JuliaInterpreter", "Random", "StaticArrays", "Test"] +test = ["ChainRulesTestUtils", "FiniteDifferences", "JLArrays", "JuliaInterpreter", "Random", "StaticArrays", "Test"] diff --git a/src/ChainRules.jl b/src/ChainRules.jl index b314d7be7..30e492d2e 100644 --- a/src/ChainRules.jl +++ b/src/ChainRules.jl @@ -1,10 +1,11 @@ module ChainRules +using Adapt: adapt using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broadcastable using ChainRulesCore using Compat using Distributed -using GPUArraysCore: AbstractGPUArrayStyle +using GPUArraysCore: AbstractGPUArray, AbstractGPUArrayStyle using IrrationalConstants: logtwo, logten using LinearAlgebra using LinearAlgebra.BLAS diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index 36ba52d55..49a4e1ac2 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -515,64 +515,14 @@ for findm in (:findmin, :findmax) @eval function rrule(::typeof($findm), x::AbstractArray; dims=:) y, ind = $findm(x; dims=dims) - project = ProjectTo(x) - # This pullback is a lot like the one for getindex. Ideally they would probably be combined? function $findm_pullback((dy, _)) # this accepts e.g. Tangent{Tuple{Float64, Int64}}(4.0, nothing) dy isa AbstractZero && return (NoTangent(), NoTangent()) - x_thunk = @thunk project(_zerolike_writeat(x, unthunk(dy), dims, ind)) - x_ithunk = InplaceableThunk(x_thunk) do dx - if dims isa Colon - view(dx, ind) .= view(dx, ind) .+ Ref(unthunk(dy)) - else - view(dx, ind) .= view(dx, ind) .+ unthunk(dy) # this could be .+=, but not on Julia 1.0 - end - dx - end - return (NoTangent(), x_ithunk) + return (NoTangent(), thunked_∇getindex(x, dy, ind),) end return (y, ind), $findm_pullback end end -# This function is roughly `setindex!(zero(x), dy, inds...)`: - -function _zerolike_writeat(x::AbstractArray{<:Number}, dy, dims, inds...) - _zero_fill = eltype(dy) == Any ? 0 : zero(eltype(dy)) - - # It's unfortunate to close over `x`, but `similar(typeof(x), axes(x))` doesn't - # allow `eltype(dy)`, nor does it work for many structured matrices. - dx = fill!(similar(x, eltype(dy), axes(x)), _zero_fill) - view(dx, inds...) .= dy # possibly 0-dim view, allows dy::Number and dy::Array, and dx::CuArray - dx -end -function _zerolike_writeat(x::AbstractArray, dy, dims, inds...) - # Since we have `x`, we can also handle arrays of arrays. - dx = map(zero, x) - if dims isa Colon - view(dx, inds...) .= Ref(dy) - else - view(dx, inds...) .= dy - end - dx -end - -# Allow for second derivatives, by writing rules for `_zerolike_writeat`; -# these rules are the reason it takes a `dims` argument. - -function frule((_, _, dẏ), ::typeof(_zerolike_writeat), x, dy, dims, inds...) - return _zerolike_writeat(x, dy, dims, inds...), _zerolike_writeat(x, dẏ, dims, inds...) -end - -function rrule(::typeof(_zerolike_writeat), x, dy, dims, inds...) - z = _zerolike_writeat(x, dy, dims, inds...) - function _zerolike_writeat_pullback(dz) - dx = sum(view(unthunk(dz), inds...); dims=dims) - nots = map(_ -> NoTangent(), inds) - return (NoTangent(), NoTangent(), dx, NoTangent(), nots...) - end - return z, _zerolike_writeat_pullback -end - # These rules for `maximum` pick the same subgradient as `findmax`: function frule((_, ẋ), ::typeof(maximum), x; dims=:) diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index c10ba6e71..d486fd67e 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -243,3 +243,23 @@ function rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(map), f::F, xs::Tu map_back(dy::AbstractZero) = (NoTangent(), NoTangent(), ntuple(Returns(NoTangent()), num_xs)...) return y, map_pullback end + +##### +##### `task_local_storage` +##### + +# Called by `@allowscalar` from GPUArrays + +ChainRules.@non_differentiable task_local_storage(key::Any) +ChainRules.@non_differentiable task_local_storage(key::Any, value::Any) + +function rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(task_local_storage), body::Function, key, value) + y, back = task_local_storage(key, value) do + rrule_via_ad(config, body) + end + function task_local_storage_pullback(dy) + dbody = only(back(dy)) + return (NoTangent(), dbody, NoTangent(), NoTangent()) + end + return y, task_local_storage_pullback +end diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl index fbce009e0..76ab59232 100644 --- a/src/rulesets/Base/indexing.jl +++ b/src/rulesets/Base/indexing.jl @@ -52,38 +52,111 @@ function rrule(::typeof(getindex), x::Tuple, ::Colon) return x, getindex_back_4 end - ##### -##### getindex +##### getindex(::AbstractArray) ##### function frule((_, ẋ), ::typeof(getindex), x::AbstractArray, inds...) return x[inds...], ẋ[inds...] end -function rrule(::typeof(getindex), x::Array{<:Number}, inds...) - # removes any logical indexing, CartesianIndex etc - # leaving us just with a tuple of Int, Arrays of Int and Ranges of Int +function rrule(::typeof(getindex), x::AbstractArray, inds...) + function getindex_pullback(dy) + nots = map(Returns(NoTangent()), inds) + return (NoTangent(), thunked_∇getindex(x, dy, inds...), nots...) + end + return x[inds...], getindex_pullback +end + +function thunked_∇getindex(x, dy, inds...) + return InplaceableThunk( + dx -> ∇getindex!(dx, unthunk(dy), Base.to_indices(x, inds)...), + @thunk(∇getindex(x, unthunk(dy), inds...)), + ) +end + +""" + ∇getindex(x, dy, inds...) + +For the `rrule` of `y = x[inds...]`, this function is roughly +`setindex(zero(x), dy, inds...)`, returning the array `dx`. +Differentiable. Includes `ProjectTo(x)(dx)`. +""" +function ∇getindex(x::AbstractArray, dy, inds...) + # `to_indices` removes any logical indexing, colons, CartesianIndex etc, + # leaving just Int / AbstractVector of Int plain_inds = Base.to_indices(x, inds) - y = getindex(x, plain_inds...) - function getindex_pullback(ȳ) - function getindex_add!(Δ) - # this a optimizes away for simple cases - for (ȳ_ii, ii) in zip(ȳ, Iterators.product(plain_inds...)) - Δ[ii...] += ȳ_ii - end - return Δ - end + dx = _setindex_zero(x, dy, plain_inds...) + ∇getindex!(dx, dy, plain_inds...) + return ProjectTo(x)(dx) # since we have x, may as well do this inside, not in rules +end + +""" + _setindex_zero(x, dy, inds...) - x̄ = InplaceableThunk( - getindex_add!, - @thunk(getindex_add!(zero(x))), - ) - īnds = broadcast(Returns(NoTangent()), inds) - return (NoTangent(), x̄, īnds...) +This returns roughly `dx = zero(x)`, except that this is guaranteed to be mutable via `similar`, +and its element type is wide enough to allow `setindex!(dx, dy, inds...)`, which is exactly what +`∇getindex` does next. + +It's unfortunate to close over `x`, but `similar(typeof(x), axes(x))` doesn't +allow `eltype(dy)`, nor does it work for many structured matrices. +""" +_setindex_zero(x::AbstractArray{<:Number}, dy, inds::Integer...) = fill!(similar(x, typeof(dy), axes(x)), ZeroTangent()) +_setindex_zero(x::AbstractArray{<:Number}, dy, inds...) = fill!(similar(x, eltype(dy), axes(x)), ZeroTangent()) +function _setindex_zero(x::AbstractArray, dy, inds::Integer...) + # This allows for types which don't define zero (like Vector) and types whose zero special (like Tangent), + # but always makes an abstract type. TODO: make it infer concrete type for e.g. vectors of SVectors + T = Union{typeof(dy), ZeroTangent} + return fill!(similar(x, T, axes(x)), ZeroTangent()) +end +function _setindex_zero(x::AbstractArray, dy, inds...) + T = Union{eltype(dy), ZeroTangent} + return fill!(similar(x, T, axes(x)), ZeroTangent()) +end +ChainRules.@non_differentiable _setindex_zero(x::AbstractArray, dy::Any, inds::Any...) + +function ∇getindex!(dx::AbstractArray, dy, inds::Integer...) + view(dx, inds...) .+= Ref(dy) + return dx +end +function ∇getindex!(dx::AbstractArray, dy, inds...) + view(dx, inds...) .+= dy + return dx +end + +# Allow for second derivatives, by writing rules for `∇getindex`: + +function frule((_, _, dẏ), ::typeof(∇getindex), x, dy, inds...) + return ∇getindex(x, dy, inds...), ∇getindex(x, dẏ, inds...) +end + +function rrule(::typeof(∇getindex), x, dy, inds...) + z = ∇getindex(x, dy, inds...) + function ∇getindex_pullback(dz) + d2y = getindex(unthunk(dz), inds...) + nots = map(Returns(NoTangent()), inds) + return (NoTangent(), NoTangent(), ProjectTo(dy)(d2y), nots...) end + return z, ∇getindex_pullback +end + +# Indexing with repeated indices on a GPU will lead ∇getindex to have race conditions & wrong answers. +# To avoid this, copy everything back to the CPU. +# But don't do that for indices which are known to be unique, e.g. `A[1, 2:3, :]` the colon gives Base.Slice: - return y, getindex_pullback +function ∇getindex!(dx::AbstractGPUArray, dy, inds::Integer...) + view(dx, inds...) .+= Ref(dy) + return dx +end +function ∇getindex!(dx::AbstractGPUArray, dy, inds::Union{Integer, AbstractUnitRange, Base.Slice}...) + view(dx, inds...) .+= dy + return dx +end +function ∇getindex!(dx::AbstractGPUArray, dy, inds...) + dx_cpu = adapt(Array, dx) + view(dx_cpu, adapt(Array, inds)...) .+= adapt(Array, dy) + copyto!(dx, dx_cpu) + return dx end ##### @@ -117,6 +190,23 @@ function frule((_, ẋ), ::typeof(view), x::AbstractArray, inds...) return view(x, inds...), view(ẋ, inds...) end +function rrule(::typeof(view), x::AbstractArray, inds...) + function view_pullback(dy) + nots = map(Returns(NoTangent()), inds) + return (NoTangent(), thunked_∇getindex(x, dy, inds...), nots...) + end + return view(x, inds...), view_pullback +end + +function rrule(::typeof(view), x::AbstractArray, i::Integer, jkl::Integer...) + # This case returns a zero-dim array, unlike getindex. So we fool ∇getindex: + function view_pullback_0(dy) + nots = map(Returns(NoTangent()), (i, jkl...)) + return (NoTangent(), thunked_∇getindex(x, dy, i:i, jkl...), nots...) + end + return view(x, i, jkl...), view_pullback_0 +end + ##### ##### setindex! ##### @@ -125,6 +215,21 @@ function frule((_, ẋ, v̇), ::typeof(setindex!), x::AbstractArray, v, inds...) return setindex!(x, v, inds...), setindex!(ẋ, v̇, inds...) end +##### +##### unsafe_getindex +##### + +# This is called by e.g. `iterate(1:0.1:2)`, +# and fixes https://github.com/FluxML/Zygote.jl/issues/1247 +# Only needs to accept AbstractRange, but AbstractVector makes testing easier. + +function frule((_, ẋ), ::typeof(Base.unsafe_getindex), x::AbstractVector, i::Integer) + return Base.unsafe_getindex(x, i), getindex(ẋ, i) +end + +function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(Base.unsafe_getindex), x::AbstractVector, i::Integer) + return rrule_via_ad(cfg, getindex, x, i) +end ##### ##### `eachslice` and friends diff --git a/src/rulesets/Base/sort.jl b/src/rulesets/Base/sort.jl index 42f674a24..0805da91f 100644 --- a/src/rulesets/Base/sort.jl +++ b/src/rulesets/Base/sort.jl @@ -62,10 +62,7 @@ function rrule(::typeof(sortslices), x::AbstractArray; dims::Integer, kw...) p = sortperm(collect(eachslice(x; dims=dims)); kw...) inds = ntuple(d -> d == dims ? p : (:), ndims(x)) function sortslices_pullback(dy) - # No actual need to zero this, and if you didn't, then you could widen eltype - # Also, you could use similar(dy) here not x, same size? - dx = _zerolike_writeat(x, unthunk(dy), (), inds...) - return (NoTangent(), ProjectTo(x)(dx)) + return (NoTangent(), ∇getindex(x, unthunk(dy), inds...)) end return x[inds...], sortslices_pullback end @@ -94,12 +91,11 @@ function rrule(::typeof(unique), x::AbstractArray{<:Number}; dims=:) mask .= (mask .== cumsum(mask, dims=1) .== true) # this implements findfirst(mask; dims=1) keep = map(I -> I[1], findall(mask)) if dims isa Colon - # The function `_zerolike_writeat` allows second derivatives. - # Should perhaps eventually be shared with `getindex`. - dx = reshape(_zerolike_writeat(vec(x), vec(dy), (), keep), axes_x) + # The function `∇getindex` allows second derivatives. + dx = reshape(∇getindex(vec(x), vec(dy), keep), axes_x) ## TODO understand again why vec! else inds = ntuple(d -> d==dims ? keep : (:), length(axes_x)) - dx = _zerolike_writeat(x, dy, (), inds...) + dx = ∇getindex(x, dy, inds...) end return (NoTangent(), ProjectTo(x)(dx)) end diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index dced0eb5e..075c5e050 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -366,14 +366,7 @@ end @test [5 0; 6 0] == @inferred unthunk(rrule(findmin, [1 2; 3 4], dims=2)[2]((hcat([5,6]), nothing))[2]) test_rrule(findmin, rand(3,4), fkwargs=(dims=1,), output_tangent = (rand(1,4), NoTangent())) test_rrule(findmin, rand(3,4), fkwargs=(dims=2,)) - - # Second derivatives - test_frule(ChainRules._zerolike_writeat, rand(2, 2), 5.0, :, CartesianIndex(2, 2)) - test_rrule(ChainRules._zerolike_writeat, rand(2, 2), 5.0, :, CartesianIndex(2, 2)) - @test_skip test_rrule(ChainRules._zerolike_writeat, rand(2, 2), 5.0, 1, [CartesianIndex(2, 1) CartesianIndex(2, 2)] ⊢ NoTangent()) # MethodError: no method matching isapprox(::Matrix{Float64}, ::Float64; rtol=1.0e-9, atol=1.0e-9) - y, bk = rrule(ChainRules._zerolike_writeat, rand(2, 2), 5.0, 1, [CartesianIndex(2, 1) CartesianIndex(2, 2)]) - @test y == [0 0; 5 5] - @test bk([1 2; 3 4]) == (NoTangent(), NoTangent(), [3 4], NoTangent(), NoTangent()) + test_rrule(findmin, rand(3,4), fkwargs=(dims=(1,2),)) end @testset "$imum" for imum in [maximum, minimum] diff --git a/test/rulesets/Base/indexing.jl b/test/rulesets/Base/indexing.jl index 29a7e2337..a0748276c 100644 --- a/test/rulesets/Base/indexing.jl +++ b/test/rulesets/Base/indexing.jl @@ -77,6 +77,95 @@ test_rrule(getindex, x, [2,2], [3,3]) end end + + @testset "getindex for structured arrays" begin + test_frule(getindex, Diagonal(rand(3)), 1) + test_frule(getindex, Symmetric(rand(3, 3)), 2, 3) + + test_rrule(getindex, Diagonal(rand(3)), 1) + @test_skip test_rrule(getindex, Diagonal(rand(3)), 2, :) # https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/260 + dgrad = rrule(getindex, Diagonal(rand(3)), 2, :)[2]([1,2,3])[2] + @test unthunk(dgrad) ≈ Diagonal([0, 2, 0]) + + test_rrule(getindex, Symmetric(rand(3, 3)), 2, 2) + sgrad = rrule(getindex, Symmetric(rand(3, 3)), 2, 3)[2](1.0)[2] + @test unthunk(sgrad) ≈ [0 0 0; 0 0 1/2; 0 1/2 0] + end + + @testset "getindex(::Array{<:Array})" begin + test_frule(getindex, [rand(2) for _ in 1:3], 1) + test_frule(getindex, [rand(2) for _ in 1:3], 2:3) + test_frule(getindex, [rand(2) for _ in 1:3], [true, false, true]) + + test_rrule(getindex, [rand(2) for _ in 1:3], 1; check_inferred=false) + test_rrule(getindex, [rand(2) for _ in 1:3], 2:3; check_inferred=false) + test_frule(getindex, [rand(2) for _ in 1:3], [true, false, true]; check_inferred=false) + end + + @testset "getindex(::Array{<:Weird})" begin + xfix = [Base.Fix1(*, pi), Base.Fix1(^, ℯ), Base.Fix1(/, -1)] + dxfix = [Tangent{Base.Fix1}(; x = i/10) for i in 1:3] + # test_frule(getindex, xfix ⊢ dxfix, 1) + # test_rrule(getindex, xfix ⊢ dxfix, 1) + + dx1 = unthunk(rrule(getindex, xfix, 1)[2](dxfix[1])[2]) + @test dx1[1] == dxfix[1] + @test iszero(dx1[2]) + + dx23 = unthunk(rrule(getindex, xfix, 2:3)[2](dxfix[2:3])[2]) + @test iszero(dx23[1]) + @test dx23[3] == dxfix[3] + end + + @testset "second derivatives: ∇getindex" begin + @eval using ChainRules: ∇getindex + # Forward, scalar result + test_frule(∇getindex, rand(2, 3), rand(), 3) + test_frule(∇getindex, rand(2, 3), rand()+im, 2, 1) + # array result + test_frule(∇getindex, rand(2, 3), rand(2), 4:5) + test_frule(∇getindex, rand(2, 3), rand(3), 1, :) + test_frule(∇getindex, rand(2, 3), rand(1, 2), [CartesianIndex(2, 1) CartesianIndex(2, 2)] ⊢ NoTangent()) + test_frule(∇getindex, rand(2, 3), rand(3), Bool[1 0 1; 0 1 0]) + # arrays of arrays + test_frule(∇getindex, [rand(2) for _ in 1:3], rand(2), 3) + test_frule(∇getindex, [rand(2) for _ in 1:3], [rand(2), rand(2)], 1:2) + + # Reverse, scalar result + test_rrule(∇getindex, rand(2, 3), rand(), 3; check_inferred=false) + test_rrule(∇getindex, rand(2, 3), rand()+im, 2, 1; check_inferred=false) + # array result + test_rrule(∇getindex, rand(2, 3), rand(2), 4:5; check_inferred=false) + test_rrule(∇getindex, rand(2, 3), rand(3), 1, :; check_inferred=false) + test_rrule(∇getindex, rand(2, 3), rand(1, 2), [CartesianIndex(2, 1) CartesianIndex(2, 2)] ⊢ NoTangent(); check_inferred=false) + test_rrule(∇getindex, rand(2, 3), rand(3), Bool[1 0 1; 0 1 0]; check_inferred=false) + # arrays of arrays + test_rrule(∇getindex, [rand(2) for _ in 1:3], rand(2), 3; check_inferred=false) + test_rrule(∇getindex, [rand(2) for _ in 1:3], [rand(2), rand(2)], 1:2; check_inferred=false) + end + + @testset "getindex(::AbstractGPUArray)" begin + x_23_gpu = jl(rand(2, 3)) # using JLArrays, loaded for @gpu in test_helpers.jl + + # Scalar indexing, copied from: @macroexpand @allowscalar A[i] + @test_skip begin # This gives + y1, bk1 = rrule(CFG, Base.task_local_storage, () -> x_23_gpu[1], :ScalarIndexing, ScalarAllowed) + @test y1 == @allowscalar x_23_gpu[1] + end + @test_skip begin + bk1(1.0) # This gives a StackOverflowError! Or gives zero in global scope. + true + end + # But this works, and calls the rule: + # Zygote.gradient(x -> @allowscalar(x[1]), jl(rand(3)))[1] + + y2, bk2 = rrule(getindex, x_23_gpu, :, 2:3) # fast path, just broadcast .+= + @test unthunk(bk2(jl(ones(2,2)))[2]) == jl([0 1 1; 0 1 1]) + + y3, bk3 = rrule(getindex, x_23_gpu, 1, [1,1,2]) # slow path, copy to CPU + @test Array(y3) == Array(x_23_gpu)[1, [1,1,2]] + @test unthunk(bk3(jl(ones(3)))[2]) == jl([2 1 0; 0 0 0]) + end end @testset "first & tail" begin @@ -100,6 +189,10 @@ end test_frule(view, rand(3, 4), :, 1) test_frule(view, rand(3, 4), 2, [1, 1, 2]) test_frule(view, rand(3, 4), 3, 4) + + test_rrule(view, rand(3, 4), :, 1) + test_rrule(view, rand(3, 4), 2, [1, 1, 2]) + test_rrule(view, rand(3, 4), 3, 4) end @testset "setindex!" begin @@ -107,6 +200,12 @@ end test_frule(setindex!, rand(3, 4), [1,10,100.0], :, 3) end +@testset "unsafe_getindex" begin + # In real life this is called only on some AbstractRanges, but easier to test on Array: + test_frule(Base.unsafe_getindex, collect(1:0.1:2), 3) + test_rrule(Base.unsafe_getindex, collect(1:0.1:2), 3) +end + @testset "eachslice" begin # Testing eachrow not collect∘eachrow leads to errors, e.g. # test_rrule: eachrow on Vector{Float64}: Error During Test at /Users/me/.julia/packages/ChainRulesTestUtils/8dFTY/src/testers.jl:195 diff --git a/test/runtests.jl b/test/runtests.jl index 9ac5c5981..71444f388 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,12 +2,15 @@ using Test, ChainRulesCore, ChainRulesTestUtils @nospecialize +using Adapt using Base.Broadcast: broadcastable using ChainRules using ChainRulesCore using ChainRulesTestUtils using ChainRulesTestUtils: rand_tangent, _fdm using FiniteDifferences +using GPUArraysCore +using JLArrays using LinearAlgebra using LinearAlgebra.BLAS using LinearAlgebra: dot