diff --git a/examples/HyperNet/Project.toml b/examples/HyperNet/Project.toml index 85e522bbb..da572377e 100644 --- a/examples/HyperNet/Project.toml +++ b/examples/HyperNet/Project.toml @@ -20,7 +20,7 @@ Lux = "1" LuxCUDA = "0.3" MLDatasets = "0.7" MLUtils = "0.4" -OneHotArrays = "0.2" +OneHotArrays = "0.2.5" Optimisers = "0.3.3, 0.4" Setfield = "1" Statistics = "1" diff --git a/examples/NeuralODE/Project.toml b/examples/NeuralODE/Project.toml index 60049f080..e9aa48aa6 100644 --- a/examples/NeuralODE/Project.toml +++ b/examples/NeuralODE/Project.toml @@ -20,7 +20,7 @@ Lux = "1" LuxCUDA = "0.3" MLDatasets = "0.7" MLUtils = "0.4" -OneHotArrays = "0.2" +OneHotArrays = "0.2.5" Optimisers = "0.3.3, 0.4" OrdinaryDiffEqTsit5 = "1" SciMLSensitivity = "7.63" diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index 3cc272fd3..49b955621 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -1,7 +1,7 @@ name = "MLDataDevices" uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" authors = ["Avik Pal and contributors"] -version = "1.5.0" +version = "1.5.1" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -18,6 +18,7 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" Metal = "dde4c033-4e86-420c-a63e-0dd931031962" +OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" @@ -35,6 +36,7 @@ MLDataDevicesFillArraysExt = "FillArrays" MLDataDevicesGPUArraysExt = "GPUArrays" MLDataDevicesMLUtilsExt = "MLUtils" MLDataDevicesMetalExt = ["GPUArrays", "Metal"] +MLDataDevicesOneHotArraysExt = "OneHotArrays" MLDataDevicesReactantExt = "Reactant" MLDataDevicesRecursiveArrayToolsExt = "RecursiveArrayTools" MLDataDevicesReverseDiffExt = "ReverseDiff" @@ -55,6 +57,7 @@ Functors = "0.4.8" GPUArrays = "10, 11" MLUtils = "0.4.4" Metal = "1" +OneHotArrays = "0.2.5" Preferences = "1.4" Random = "1.10" Reactant = "0.2.4" diff --git a/lib/MLDataDevices/ext/MLDataDevicesOneHotArraysExt.jl b/lib/MLDataDevices/ext/MLDataDevicesOneHotArraysExt.jl new file mode 100644 index 000000000..ceb6d6bde --- /dev/null +++ b/lib/MLDataDevices/ext/MLDataDevicesOneHotArraysExt.jl @@ -0,0 +1,17 @@ +module MLDataDevicesOneHotArraysExt + +using Adapt: Adapt +using MLDataDevices: MLDataDevices, Internal, ReactantDevice, CPUDevice +using OneHotArrays: OneHotArray + +for op in (:get_device, :get_device_type) + @eval Internal.$(op)(x::OneHotArray) = Internal.$(op)(x.indices) +end + +# Reactant doesn't pay very nicely with OneHotArrays at the moment +function Adapt.adapt_structure(dev::ReactantDevice, x::OneHotArray) + x_cpu = Adapt.adapt_structure(CPUDevice(), x) + return Adapt.adapt_storage(dev, convert(Array, x_cpu)) +end + +end diff --git a/lib/MLDataDevices/test/Project.toml b/lib/MLDataDevices/test/Project.toml index 9914e0f57..1fb732d37 100644 --- a/lib/MLDataDevices/test/Project.toml +++ b/lib/MLDataDevices/test/Project.toml @@ -9,6 +9,7 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" +OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" @@ -30,6 +31,7 @@ FillArrays = "1" ForwardDiff = "0.10.36" Functors = "0.4.8" MLUtils = "0.4" +OneHotArrays = "0.2.5" Pkg = "1.10" Random = "1.10" RecursiveArrayTools = "3.8" diff --git a/lib/MLDataDevices/test/misc_tests.jl b/lib/MLDataDevices/test/misc_tests.jl index 28275d3b7..5ece810bf 100644 --- a/lib/MLDataDevices/test/misc_tests.jl +++ b/lib/MLDataDevices/test/misc_tests.jl @@ -5,6 +5,8 @@ using ReverseDiff, Tracker, ForwardDiff using SparseArrays, FillArrays, Zygote, RecursiveArrayTools using Functors: Functors +const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "none")) + @testset "Issues Patches" begin @testset "#10 patch" begin dev = CPUDevice() @@ -219,3 +221,23 @@ end @test only(Zygote.gradient(x -> sum(abs2, gdev(x)), x')) isa Matrix{Float64} end + +@testset "OneHotArrays" begin + using OneHotArrays + + x = onehotbatch("abracadabra", 'a':'e', 'e') + @test get_device(x) isa CPUDevice + + gdev = gpu_device() + x_g = gdev(x) + @test get_device(x_g) isa parameterless_type(typeof(gdev)) + + if BACKEND_GROUP == "none" || BACKEND_GROUP == "reactant" + using Reactant + + rdev = reactant_device() + x_rd = rdev(x) + @test get_device(x_rd) isa ReactantDevice + @test x_rd isa Reactant.ConcreteRArray{Bool, 2} + end +end