This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit ece7ba2

fix: correctly handle adjoints of wrapped arrays (#90)
* fix: correctly handle adjoints of wrapped arrays
* fix: use fast paths for adapt
* fix: adapt ranges to JuliaGPU/Adapt.jl#86

avik-pal authored Oct 25, 2024
1 parent e9a2ed7 commit ece7ba2
Showing 9 changed files with 44 additions and 38 deletions.
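At a glance: calling a device object on a wrapped array such as an `Adjoint` now adapts the wrapper as a whole, and the adjoint rule projects gradients back onto the wrapper's structure. A minimal sketch of the newly supported behavior, mirroring the test added in test/misc_tests.jl below (CPU only, no GPU assumed):

```julia
using Zygote, MLDataDevices

x = rand(4, 4)
cdev = cpu_device()

# Differentiating through a device transfer of an Adjoint-wrapped matrix
# now yields a gradient of the expected type instead of mishandling the wrapper.
g = only(Zygote.gradient(z -> sum(abs2, cdev(z)), x'))
g isa Matrix{Float64}  # true
```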
6 changes: 2 additions & 4 deletions Project.toml
@@ -1,13 +1,12 @@
 name = "MLDataDevices"
 uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
 authors = ["Avik Pal <[email protected]> and contributors"]
-version = "1.4.1"
+version = "1.4.2"

 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
-LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Preferences = "21216c6a-2e73-6563-6e65-726566657250"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -47,14 +46,13 @@ MLDataDevicesoneAPIExt = ["GPUArrays", "oneAPI"]

 [compat]
 AMDGPU = "0.9.6, 1"
-Adapt = "4"
+Adapt = "4.1"
 CUDA = "5.2"
 ChainRulesCore = "1.23"
 Compat = "4.15"
 FillArrays = "1"
 Functors = "0.4.8"
 GPUArrays = "10, 11"
-LinearAlgebra = "1.10"
 MLUtils = "0.4.4"
 Metal = "1"
 Preferences = "1.4"
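The `Adapt` lower bound moves to 4.1 to pick up JuliaGPU/Adapt.jl#86, which moves ranges across devices without collecting them into dense arrays; that is what allows the hand-written `adapt_storage(..., ::AbstractRange)` methods to be deleted from src/public.jl below. A sketch of the resulting behavior, assuming a functional GPU backend is loaded:

```julia
using MLDataDevices

gdev = gpu_device()
r = gdev(range(0.0f0, 1.0f0; length=10))

# With Adapt >= 4.1 the range stays a range on the device, matching the
# updated `isa AbstractRange` assertions in the test files below.
r isa AbstractRange  # true
```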
21 changes: 12 additions & 9 deletions ext/MLDataDevicesChainRulesCoreExt.jl
@@ -1,24 +1,27 @@
 module MLDataDevicesChainRulesCoreExt

 using Adapt: Adapt
-using ChainRulesCore: ChainRulesCore, NoTangent, @non_differentiable
+using ChainRulesCore: ChainRulesCore, NoTangent, ProjectTo, @non_differentiable

 using MLDataDevices: AbstractDevice, UnknownDevice, get_device, get_device_type

 @non_differentiable get_device(::Any)
 @non_differentiable get_device_type(::Any)

-function ChainRulesCore.rrule(
-        ::typeof(Adapt.adapt_storage), to::AbstractDevice, x::AbstractArray)
-    ∇adapt_storage = let dev = get_device(x)
-        if dev === nothing || dev isa UnknownDevice
+function ChainRulesCore.rrule(::typeof(Adapt.adapt), to::AbstractDevice, x::AbstractArray)
+    dev = get_device(x)
+    y = Adapt.adapt_storage(to, x)
+    if dev === nothing || dev isa UnknownDevice
         dev isa UnknownDevice &&
             @warn "`get_device(::$(typeof(x)))` returned `$(dev)`." maxlog=1
-            Δ -> (NoTangent(), NoTangent(), Δ)
-        else
-            Δ -> (NoTangent(), NoTangent(), dev(Δ))
+        ∇adapt_storage_unknown = Δ -> (NoTangent(), NoTangent(), Δ)
+        return y, ∇adapt_storage_unknown
+    else
+        ∇adapt_storage = let dev = dev, x = x
+            Δ -> (NoTangent(), NoTangent(), ProjectTo(x)(dev(Δ)))
+        end
+        return Adapt.adapt_storage(to, x), ∇adapt_storage
     end
-    return Adapt.adapt_storage(to, x), ∇adapt_storage
 end

 end
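Two things change in this rule: it now hooks `Adapt.adapt` (the entry point wrapped arrays go through) rather than `Adapt.adapt_storage`, and the cotangent is passed through `ProjectTo(x)` so it is projected back onto the structure of the primal. The projection itself is standard `ChainRulesCore` behavior, illustrated standalone:

```julia
using ChainRulesCore, LinearAlgebra

x = rand(3, 3)'  # Adjoint{Float64, Matrix{Float64}}
Δ = ones(3, 3)   # a plain cotangent, as a device transfer might produce

# ProjectTo re-imposes the primal's wrapper structure on the cotangent:
ProjectTo(x)(Δ) isa Adjoint  # true
```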
1 change: 0 additions & 1 deletion src/MLDataDevices.jl
@@ -5,7 +5,6 @@ using Functors: Functors, fleaves
 using Preferences: @delete_preferences!, @load_preference, @set_preferences!
 using Random: AbstractRNG, Random
 using Compat: @compat
-using LinearAlgebra: Transpose, Adjoint

 abstract type AbstractDevice <: Function end
 abstract type AbstractCPUDevice <: AbstractDevice end
16 changes: 5 additions & 11 deletions src/public.jl
@@ -342,8 +342,10 @@ for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI, :XLA)
     ldev = Symbol(dev, :Device)
     @eval begin
         function (D::$(ldev))(x::AbstractArray{T}) where {T}
-            return (isbitstype(T) || Internal.special_aos(x)) ? Adapt.adapt(D, x) :
-                   map(D, x)
+            if isbitstype(T) || Internal.special_aos(x) || x isa Adapt.WrappedArray
+                return Adapt.adapt(D, x)
+            end
+            return map(D, x)
         end
         (D::$(ldev))(x::Union{Tuple, NamedTuple}) = map(D, x)
         function (D::$(ldev))(x)
@@ -373,14 +375,6 @@ for T in (AMDGPUDevice, CUDADevice, MetalDevice, oneAPIDevice, XLADevice)
     end
 end

-Adapt.adapt_storage(::CPUDevice, x::AbstractRange) = x
-Adapt.adapt_storage(::XLADevice, x::AbstractRange) = x
-# Prevent Ambiguity
-for T in (AMDGPUDevice, AMDGPUDevice{Nothing}, CUDADevice,
-        CUDADevice{Nothing}, MetalDevice, oneAPIDevice)
-    @eval Adapt.adapt_storage(to::$(T), x::AbstractRange) = Adapt.adapt(to, collect(x))
-end
-
 """
     isleaf(x) -> Bool
@@ -399,4 +393,4 @@ If `MLDataDevices.isleaf(x::T)` is not defined, then it will fall back to `Functors.isleaf(x)`.
 isleaf(x) = Functors.isleaf(x)

 isleaf(::AbstractArray{T}) where {T} = isbitstype(T)
-isleaf(::Union{Transpose, Adjoint, PermutedDimsArray}) = false
+isleaf(::Adapt.WrappedArray) = false
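`Adapt.WrappedArray` is the union of all array-wrapper types Adapt.jl knows how to unwrap, so this one method subsumes the old `Transpose`/`Adjoint`/`PermutedDimsArray` union and also catches wrappers it missed, such as `SubArray`. A quick check:

```julia
using Adapt, LinearAlgebra

A = rand(4, 4)
A' isa Adapt.WrappedArray               # true (Adjoint, covered before)
transpose(A) isa Adapt.WrappedArray     # true (Transpose, covered before)
view(A, :, 1:2) isa Adapt.WrappedArray  # true (SubArray, newly treated as a non-leaf)
```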
4 changes: 2 additions & 2 deletions test/amdgpu_tests.jl
@@ -53,7 +53,7 @@ using FillArrays, Zygote # Extensions
     @test ps_xpu.mixed[1] isa Float32
     @test ps_xpu.mixed[2] isa Float64
     @test ps_xpu.mixed[3] isa aType
-    @test ps_xpu.range isa aType
+    @test ps_xpu.range isa AbstractRange
     @test ps_xpu.e == ps.e
     @test ps_xpu.d == ps.d
     @test ps_xpu.rng_default isa rngType
@@ -83,7 +83,7 @@ using FillArrays, Zygote # Extensions
     @test ps_cpu.mixed[1] isa Float32
     @test ps_cpu.mixed[2] isa Float64
     @test ps_cpu.mixed[3] isa Array
-    @test ps_cpu.range isa Array
+    @test ps_cpu.range isa AbstractRange
     @test ps_cpu.e == ps.e
     @test ps_cpu.d == ps.d
     @test ps_cpu.rng_default isa Random.TaskLocalRNG
4 changes: 2 additions & 2 deletions test/cuda_tests.jl
@@ -52,7 +52,7 @@ using FillArrays, Zygote # Extensions
     @test ps_xpu.mixed[1] isa Float32
     @test ps_xpu.mixed[2] isa Float64
     @test ps_xpu.mixed[3] isa aType
-    @test ps_xpu.range isa aType
+    @test ps_xpu.range isa AbstractRange
     @test ps_xpu.e == ps.e
     @test ps_xpu.d == ps.d
     @test ps_xpu.rng_default isa rngType
@@ -82,7 +82,7 @@ using FillArrays, Zygote # Extensions
     @test ps_cpu.mixed[1] isa Float32
     @test ps_cpu.mixed[2] isa Float64
     @test ps_cpu.mixed[3] isa Array
-    @test ps_cpu.range isa Array
+    @test ps_cpu.range isa AbstractRange
     @test ps_cpu.e == ps.e
     @test ps_cpu.d == ps.d
     @test ps_cpu.rng_default isa Random.TaskLocalRNG
4 changes: 2 additions & 2 deletions test/metal_tests.jl
@@ -51,7 +51,7 @@ using FillArrays, Zygote # Extensions
     @test ps_xpu.mixed[1] isa Float32
     @test ps_xpu.mixed[2] isa Float64
     @test ps_xpu.mixed[3] isa aType
-    @test ps_xpu.range isa aType
+    @test ps_xpu.range isa AbstractRange
     @test ps_xpu.e == ps.e
     @test ps_xpu.d == ps.d
     @test ps_xpu.rng_default isa rngType
@@ -81,7 +81,7 @@ using FillArrays, Zygote # Extensions
     @test ps_cpu.mixed[1] isa Float32
     @test ps_cpu.mixed[2] isa Float64
     @test ps_cpu.mixed[3] isa Array
-    @test ps_cpu.range isa Array
+    @test ps_cpu.range isa AbstractRange
     @test ps_cpu.e == ps.e
     @test ps_cpu.d == ps.d
     @test ps_cpu.rng_default isa Random.TaskLocalRNG
22 changes: 17 additions & 5 deletions test/misc_tests.jl
@@ -50,17 +50,17 @@ end

 @testset "CRC Tests" begin
     dev = cpu_device() # Other devices don't work with FiniteDifferences.jl
-    test_rrule(Adapt.adapt_storage, dev, randn(Float64, 10); check_inferred=true)
+    test_rrule(Adapt.adapt, dev, randn(Float64, 10); check_inferred=true)

     gdev = gpu_device()
     if !(gdev isa MetalDevice) # On intel devices causes problems
         x = randn(10)
-        ∂dev, ∂x = Zygote.gradient(sum ∘ Adapt.adapt_storage, gdev, x)
+        ∂dev, ∂x = Zygote.gradient(sum ∘ Adapt.adapt, gdev, x)
         @test ∂dev === nothing
         @test ∂x ≈ ones(10)

         x = randn(10) |> gdev
-        ∂dev, ∂x = Zygote.gradient(sum ∘ Adapt.adapt_storage, cpu_device(), x)
+        ∂dev, ∂x = Zygote.gradient(sum ∘ Adapt.adapt, cpu_device(), x)
         @test ∂dev === nothing
         @test ∂x ≈ gdev(ones(10))
         @test get_device(∂x) isa parameterless_type(typeof(gdev))
@@ -181,7 +181,6 @@ end
 end

 @testset "shared parameters" begin
-    # from
     x = rand(1)
     m = (; a=x, b=x')
     count = Ref(0)
@@ -199,11 +198,24 @@ end
             y::Float64
         end

-        for x in [1.0, 'a', BitsType(1, 2.0)]
+        @testset for x in [1.0, 'a', BitsType(1, 2.0)]
             @test MLDataDevices.isleaf([x])
             @test !MLDataDevices.isleaf([x]')
+            @test !MLDataDevices.isleaf(transpose([x]))
+            @test !MLDataDevices.isleaf(PermutedDimsArray([x;;], (1, 2)))
         end
     end
 end

+@testset "Zygote.gradient(wrapped arrays)" begin
+    using Zygote
+
+    x = rand(4, 4)
+    cdev = cpu_device()
+
+    @test only(Zygote.gradient(x -> sum(abs2, cdev(x)), x')) isa Matrix{Float64}
+
+    gdev = gpu_device()
+
+    @test only(Zygote.gradient(x -> sum(abs2, gdev(x)), x')) isa Matrix{Float64}
+end
4 changes: 2 additions & 2 deletions test/oneapi_tests.jl
@@ -51,7 +51,7 @@ using FillArrays, Zygote # Extensions
     @test ps_xpu.mixed[1] isa Float32
     @test ps_xpu.mixed[2] isa Float64
     @test ps_xpu.mixed[3] isa aType
-    @test ps_xpu.range isa aType
+    @test ps_xpu.range isa AbstractRange
     @test ps_xpu.e == ps.e
     @test ps_xpu.d == ps.d
     @test ps_xpu.rng_default isa rngType
@@ -81,7 +81,7 @@ using FillArrays, Zygote # Extensions
     @test ps_cpu.mixed[1] isa Float32
     @test ps_cpu.mixed[2] isa Float64
     @test ps_cpu.mixed[3] isa Array
-    @test ps_cpu.range isa Array
+    @test ps_cpu.range isa AbstractRange
     @test ps_cpu.e == ps.e
     @test ps_cpu.d == ps.d
     @test ps_cpu.rng_default isa Random.TaskLocalRNG

2 comments on commit ece7ba2

@avik-pal (Member, Author)

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/118090

Tip: Release Notes

Did you know you can add release notes too? Just add markdown-formatted text underneath the comment after the text "Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the release that TagBot creates. E.g.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the GitHub interface, or via:

git tag -a v1.4.2 -m "<description of version>" ece7ba2b33564dd4705f573da9cf382f6231f09c
git push origin v1.4.2
