Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
add tests and docs
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Oct 18, 2024
1 parent 7202dbc commit 39a4dea
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 2 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ version = "1.3.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand Down Expand Up @@ -48,6 +49,7 @@ AMDGPU = "0.9.6, 1"
Adapt = "4"
CUDA = "5.2"
ChainRulesCore = "1.23"
Compat = "4.16.0"
FillArrays = "1"
Functors = "0.4.8"
GPUArrays = "10, 11"
Expand Down
4 changes: 2 additions & 2 deletions src/MLDataDevices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using Adapt: Adapt
using Functors: Functors, fleaves
using Preferences: @delete_preferences!, @load_preference, @set_preferences!
using Random: AbstractRNG, Random
using Compat: @compat

abstract type AbstractDevice <: Function end
abstract type AbstractCPUDevice <: AbstractDevice end
Expand All @@ -25,7 +26,6 @@ export get_device, get_device_type

export DeviceIterator

### uncomment below when min supported julia version is >=1.11
# public isleaf
@compat(public, (isleaf,))

end
14 changes: 14 additions & 0 deletions src/public.jl
Original file line number Diff line number Diff line change
Expand Up @@ -381,5 +381,19 @@ for T in (AMDGPUDevice, AMDGPUDevice{Nothing}, CUDADevice,
@eval Adapt.adapt_storage(to::$(T), x::AbstractRange) = Adapt.adapt(to, collect(x))
end

"""
isleaf(x) -> Bool
Returns `true` if `x` is a leaf node in the data structure.
Defining `MLDataDevices.isleaf(x::T) = true` for custom types
can be used to customize the behavior the data movement behavior
when an object with nested structure containing the type is transferred to a device.
`Adapt.adapt_structure(::AbstractDevice, x::T)` or
`Adapt.adapt_structure(::AbstractDevice, x::T)` will be called during
data movement if `isleaf(x::T) == true`.
If `MLDataDevices.isleaf(x::T)` is not defined, then it will fall back to `Functors.isleaf(x)`.
"""
isleaf(x) = Functors.isleaf(x)
19 changes: 19 additions & 0 deletions test/misc_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,3 +157,22 @@ end
@test get_device(x) isa MLDataDevices.UnknownDevice
@test get_device_type(x) <: MLDataDevices.UnknownDevice
end

@testset "isleaf" begin
# Functors.isleaf fallback
@test MLDataDevices.isleaf(rand(2))
@test !MLDataDevices.isleaf((rand(2),))

struct Tleaf
x
end
Functors.@functor Tleaf

MLDataDevices.isleaf(::Tleaf) = true

Adapt.adapt_structure(dev::CPUDevice, t::Tleaf) = Tleaf(2 .* dev(t.x))

cpu = cpu_device()
t = Tleaf(ones(2))
@test cpu(t).x == 2 .* ones(2)
end

0 comments on commit 39a4dea

Please sign in to comment.