diff --git a/.buildkite/testing_luxtestutils.yml b/.buildkite/testing_luxtestutils.yml index 32434dd2e2..8d5301c3b5 100644 --- a/.buildkite/testing_luxtestutils.yml +++ b/.buildkite/testing_luxtestutils.yml @@ -9,12 +9,22 @@ steps: codecov: true dirs: - lib/LuxTestUtils/src + - lib/MLDataDevices/src + - lib/MLDataDevices/ext command: | julia --color=yes --code-coverage=user --depwarn=yes --project=lib/LuxTestUtils -e ' import Pkg; Pkg.Registry.update(); - Pkg.instantiate(); - Pkg.test(; coverage="user")' + dev_pkgs = Pkg.PackageSpec[]; + for pkg in ("lib/MLDataDevices",) + push!(dev_pkgs, Pkg.PackageSpec(path=pkg)); + end; + Pkg.develop(dev_pkgs); + Pkg.instantiate()' + julia --color=yes --code-coverage=user --depwarn=yes --project=lib/LuxTestUtils/test -e ' + import Pkg, LuxTestUtils + dir = dirname(pathof(LuxTestUtils)) + include(joinpath(dir, "../test/runtests.jl"))' agents: queue: "juliagpu" cuda: "*" diff --git a/.github/workflows/CI_LuxTestUtils.yml b/.github/workflows/CI_LuxTestUtils.yml index 51e6f1ee20..ccd9c5db56 100644 --- a/.github/workflows/CI_LuxTestUtils.yml +++ b/.github/workflows/CI_LuxTestUtils.yml @@ -45,13 +45,18 @@ jobs: - name: "Install Dependencies and Run Tests" run: | import Pkg + dev_pkgs = Pkg.PackageSpec[] + for pkg in ("lib/MLDataDevices",) + push!(dev_pkgs, Pkg.PackageSpec(path=pkg)) + end + Pkg.develop(dev_pkgs) Pkg.Registry.update() Pkg.instantiate() Pkg.test(; coverage="user") shell: julia --color=yes --code-coverage=user --depwarn=yes --project=lib/LuxTestUtils {0} - uses: julia-actions/julia-processcoverage@v1 with: - directories: lib/LuxTestUtils/src + directories: lib/LuxTestUtils/src,lib/MLDataDevices/src,lib/MLDataDevices/ext - uses: codecov/codecov-action@v4 with: files: lcov.info @@ -75,6 +80,11 @@ jobs: - name: "Install Dependencies and Run Tests" run: | import Pkg + dev_pkgs = Pkg.PackageSpec[] + for pkg in ("lib/MLDataDevices",) + push!(dev_pkgs, Pkg.PackageSpec(path=pkg)) + end + Pkg.develop(dev_pkgs) Pkg.Registry.update() Pkg.instantiate() Pkg.test(; coverage="user") diff --git a/Project.toml b/Project.toml index 140fb7cf31..59ae0f0b3e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Lux" uuid = "b2108857-7c20-44ae-9111-449ecde12c47" authors = ["Avik Pal and contributors"] -version = "1.2.3" +version = "1.3.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -66,14 +66,14 @@ LuxTrackerExt = "Tracker" LuxZygoteExt = "Zygote" [compat] -ADTypes = "1.8.1" +ADTypes = "1.10" Adapt = "4.1" ArgCheck = "2.3" ArrayInterface = "7.10" CUDA = "5.3.2" ChainRulesCore = "1.24" -Compat = "4.15" -ComponentArrays = "0.15.16" +Compat = "4.16" +ComponentArrays = "0.15.18" ConcreteStructs = "0.2.3" DispatchDoctor = "0.4.12" Enzyme = "0.13.13" @@ -82,20 +82,20 @@ FastClosures = "0.3.2" Flux = "0.14.25" ForwardDiff = "0.10.36" FunctionWrappers = "1.1.3" -Functors = "0.4.12" +Functors = "0.5" GPUArraysCore = "0.1.6, 0.2" LinearAlgebra = "1.10" LossFunctions = "0.11.1" -LuxCore = "1" +LuxCore = "1.2" LuxLib = "1.3.7" -MLDataDevices = "1.5" +MLDataDevices = "1.6" MLUtils = "0.4.4" MPI = "0.20.19" MacroTools = "0.5.13" Markdown = "1.10" NCCL = "0.1.1" NNlib = "0.9.24" -Optimisers = "0.3.3, 0.4" +Optimisers = "0.3.4, 0.4" Preferences = "1.4.3" Random = "1.10" Reactant = "0.2.4" @@ -107,7 +107,7 @@ SimpleChains = "0.4.7" Static = "1.1.1" StaticArraysCore = "1.4.3" Statistics = "1.10" -Tracker = "0.2.34" +Tracker = "0.2.36" WeightInitializers = "1" Zygote = "0.6.70" julia = "1.10" diff --git a/docs/Project.toml b/docs/Project.toml index 7ef3bb514b..6655cf829c 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -30,27 +30,27 @@ WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -ADTypes = "1.3" +ADTypes = "1.10" Adapt = "4" ChainRulesCore = "1.24" -ComponentArrays = "0.15" +ComponentArrays = "0.15.18" Documenter = "1.4" DocumenterVitepress = "0.1.3" Enzyme = "0.13.13" FiniteDiff = "2.23.1" ForwardDiff = "0.10.36" -Functors = "0.4.12" +Functors = "0.5" GPUArraysCore = "0.1, 0.2" KernelAbstractions = "0.9" LinearAlgebra = "1.10" Literate = "2.18.0" Lux = "1" LuxCUDA = "0.3.2" -LuxCore = "1" +LuxCore = "1.2" LuxLib = "1.3.4" LuxTestUtils = "1.5" -MLDataDevices = "1.4" -Optimisers = "0.3.3, 0.4" +MLDataDevices = "1.6" +Optimisers = "0.3.4, 0.4" Pkg = "1.10" Printf = "1.10" Random = "1.10" diff --git a/docs/src/manual/nn_inside_gpu_kernels.md b/docs/src/manual/nn_inside_gpu_kernels.md index 382fad1338..b109c24650 100644 --- a/docs/src/manual/nn_inside_gpu_kernels.md +++ b/docs/src/manual/nn_inside_gpu_kernels.md @@ -21,7 +21,7 @@ making it compatible with multiple GPU backends. input data and let Lux handle the batching internally. ```@example nn_in_gpu_kernels -using Lux, LuxCUDA, Random +using Lux, LuxCUDA, Random, Functors using KernelAbstractions, StaticArrays ``` @@ -45,8 +45,8 @@ nn = Chain(Dense(4, 4, relu), Dense(4, 4)) ps, st = Lux.setup(Xoshiro(123), nn) to_sarray(x) = SArray{Tuple{size(x)...}}(x) -ps_static = Lux.recursive_map(to_sarray, ps) -st_static = Lux.recursive_map(to_sarray, st) +ps_static = fmap(to_sarray, ps) +st_static = fmap(to_sarray, st) ``` First we will run it on CPU. diff --git a/examples/Basics/Project.toml b/examples/Basics/Project.toml index 4e44d404cf..1e32e0ea08 100644 --- a/examples/Basics/Project.toml +++ b/examples/Basics/Project.toml @@ -9,9 +9,9 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -ComponentArrays = "0.15" +ComponentArrays = "0.15.18" ForwardDiff = "0.10" Lux = "1" LuxCUDA = "0.3" -Optimisers = "0.3.3, 0.4" +Optimisers = "0.3.4, 0.4" Zygote = "0.6" diff --git a/examples/BayesianNN/Project.toml b/examples/BayesianNN/Project.toml index 87f4f34339..1ec9aa9620 100644 --- a/examples/BayesianNN/Project.toml +++ b/examples/BayesianNN/Project.toml @@ -10,10 +10,10 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] CairoMakie = "0.12" -Functors = "0.4" +Functors = "0.5" LinearAlgebra = "1" Lux = "1" Random = "1" -Tracker = "0.2" +Tracker = "0.2.36" Turing = "0.34, 0.35" Zygote = "0.6.69" diff --git a/examples/GravitationalWaveForm/Project.toml b/examples/GravitationalWaveForm/Project.toml index 1604f96bac..0d28a37f46 100644 --- a/examples/GravitationalWaveForm/Project.toml +++ b/examples/GravitationalWaveForm/Project.toml @@ -12,7 +12,7 @@ SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" [compat] CairoMakie = "0.12" -ComponentArrays = "0.15" +ComponentArrays = "0.15.18" LineSearches = "7" Lux = "1" Optimization = "4" diff --git a/examples/HyperNet/Project.toml b/examples/HyperNet/Project.toml index da572377e3..501036806c 100644 --- a/examples/HyperNet/Project.toml +++ b/examples/HyperNet/Project.toml @@ -14,14 +14,14 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -ADTypes = "1" -ComponentArrays = "0.15" +ADTypes = "1.10" +ComponentArrays = "0.15.18" Lux = "1" LuxCUDA = "0.3" MLDatasets = "0.7" MLUtils = "0.4" OneHotArrays = "0.2.5" -Optimisers = "0.3.3, 0.4" +Optimisers = "0.3.4, 0.4" Setfield = "1" Statistics = "1" Zygote = "0.6" diff --git a/examples/ImageNet/Project.toml b/examples/ImageNet/Project.toml index 792a1341c9..5a0c022559 100644 --- a/examples/ImageNet/Project.toml +++ b/examples/ImageNet/Project.toml @@ -33,12 +33,12 @@ ImageMagick = "1" JLD2 = "0.5.1" Lux = "1" LuxCUDA = "0.3.3" -MLDataDevices = "1.3" +MLDataDevices = "1.6" MLUtils = "0.4.4" MPI = "0.20.21" NCCL = "0.1.1" OneHotArrays = "0.2.5" -Optimisers = "0.3.3, 0.4" +Optimisers = "0.3.4, 0.4" ParameterSchedulers = "0.4.2" Random = "1.10" Setfield = "1.1.1" diff --git a/examples/NeuralODE/Project.toml b/examples/NeuralODE/Project.toml index e9aa48aa67..9bb2456cc1 100644 --- a/examples/NeuralODE/Project.toml +++ b/examples/NeuralODE/Project.toml @@ -15,13 +15,13 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -ComponentArrays = "0.15" +ComponentArrays = "0.15.18" Lux = "1" LuxCUDA = "0.3" MLDatasets = "0.7" MLUtils = "0.4" OneHotArrays = "0.2.5" -Optimisers = "0.3.3, 0.4" +Optimisers = "0.3.4, 0.4" OrdinaryDiffEqTsit5 = "1" SciMLSensitivity = "7.63" Statistics = "1" diff --git a/examples/OptimizationIntegration/Project.toml b/examples/OptimizationIntegration/Project.toml index 7c6e88c080..e43832d037 100644 --- a/examples/OptimizationIntegration/Project.toml +++ b/examples/OptimizationIntegration/Project.toml @@ -14,7 +14,7 @@ SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" [compat] CairoMakie = "0.12.10" -ComponentArrays = "0.15.17" +ComponentArrays = "0.15.18" Lux = "1" LuxCUDA = "0.3.3" MLUtils = "0.4.4" diff --git a/examples/PINN2DPDE/Project.toml b/examples/PINN2DPDE/Project.toml index 03e427a64f..57abc282ed 100644 --- a/examples/PINN2DPDE/Project.toml +++ b/examples/PINN2DPDE/Project.toml @@ -12,13 +12,13 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -ADTypes = "1.8.1" +ADTypes = "1.10" CairoMakie = "0.12.10" Lux = "1" LuxCUDA = "0.3.3" MLUtils = "0.4.4" OnlineStats = "1.7.1" -Optimisers = "0.3.3, 0.4" +Optimisers = "0.3.4, 0.4" Printf = "1.10" Random = "1.10" Statistics = "1.10" diff --git a/examples/PolynomialFitting/Project.toml b/examples/PolynomialFitting/Project.toml index 5cf0394ef1..36c4d4e59a 100644 --- a/examples/PolynomialFitting/Project.toml +++ b/examples/PolynomialFitting/Project.toml @@ -10,10 +10,10 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -ADTypes = "1" +ADTypes = "1.10" CairoMakie = "0.12" Lux = "1" LuxCUDA = "0.3" -Optimisers = "0.3.3, 0.4" +Optimisers = "0.3.4, 0.4" Statistics = "1" Zygote = "0.6" diff --git a/examples/SimpleChains/Project.toml b/examples/SimpleChains/Project.toml index 33304a7dce..c8ec5a59bc 100644 --- a/examples/SimpleChains/Project.toml +++ b/examples/SimpleChains/Project.toml @@ -11,12 +11,12 @@ SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -ADTypes = "1" +ADTypes = "1.10" Lux = "1" MLDatasets = "0.7.14" MLUtils = "0.4" OneHotArrays = "0.2.5" -Optimisers = "0.3.3, 0.4" +Optimisers = "0.3.4, 0.4" Random = "1" SimpleChains = "0.4.6" Zygote = "0.6.69" diff --git a/examples/SimpleRNN/Project.toml b/examples/SimpleRNN/Project.toml index 02dff511c1..78cbad4b65 100644 --- a/examples/SimpleRNN/Project.toml +++ b/examples/SimpleRNN/Project.toml @@ -11,11 +11,11 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -ADTypes = "1" +ADTypes = "1.10" JLD2 = "0.5" Lux = "1" LuxCUDA = "0.3" MLUtils = "0.4" -Optimisers = "0.3.3, 0.4" +Optimisers = "0.3.4, 0.4" Statistics = "1" Zygote = "0.6" diff --git a/ext/LuxEnzymeExt/LuxEnzymeExt.jl b/ext/LuxEnzymeExt/LuxEnzymeExt.jl index 0174f39723..1c6b5cc887 100644 --- a/ext/LuxEnzymeExt/LuxEnzymeExt.jl +++ b/ext/LuxEnzymeExt/LuxEnzymeExt.jl @@ -3,11 +3,13 @@ module LuxEnzymeExt using ADTypes: AutoEnzyme using Enzyme: Enzyme, Active, Const, Duplicated using EnzymeCore: EnzymeCore +using Functors: fmap using Setfield: @set! using Static: False, True -using Lux: Lux +using Lux: Lux, Utils using Lux.Training: TrainingBackendCache, TrainState +using MLDataDevices: isleaf include("training.jl") diff --git a/ext/LuxEnzymeExt/training.jl b/ext/LuxEnzymeExt/training.jl index 3718379bf1..79c950a4b5 100644 --- a/ext/LuxEnzymeExt/training.jl +++ b/ext/LuxEnzymeExt/training.jl @@ -1,6 +1,6 @@ function Lux.Training.compute_gradients_impl( ad::AutoEnzyme, obj_fn::F, data, ts::TrainState) where {F} - dps = Lux.recursive_make_zero(ts.parameters) + dps = Lux.Training.dparameters(ts.cache) obj_fn_wrap, st_wrap, stats_wrap = Lux.Training.wrap_objective_function( obj_fn, ts.model, ts.parameters, ts.states, data, True()) @@ -22,8 +22,7 @@ const AUTODIFF_CACHE_TYPE = TrainingBackendCache{ function Lux.Training.compute_gradients_impl( ::AutoEnzyme, obj_fn::F, data, ts::TrainState{<:AUTODIFF_CACHE_TYPE, F}) where {F} - Enzyme.make_zero!(ts.cache.dparameters) - dps = ts.cache.dparameters + dps = Lux.Training.dparameters(ts.cache) _, loss = Enzyme.autodiff( EnzymeCore.ReverseWithPrimal, Const(ts.cache.extras.obj_fn), Active, @@ -57,14 +56,16 @@ const AUTODIFF_THUNK_CACHE_TYPE = TrainingBackendCache{ function Lux.Training.compute_gradients_impl(::AutoEnzyme, obj_fn::F, data, ts::TrainState{<:AUTODIFF_THUNK_CACHE_TYPE, F}) where {F} - dps = Lux.recursive_make_zero!!(ts.cache.dparameters) + dps = Lux.Training.dparameters(ts.cache) params = Duplicated(ts.parameters, dps) tape, (loss, st_, stats), _ = ts.cache.extras.forward( Const(obj_fn), Const(ts.model), params, Const(ts.states), Const(data)) ts.cache.extras.reverse( Const(obj_fn), Const(ts.model), params, Const(ts.states), Const(data), - (one(loss), Lux.recursive_make_zero(st_), Lux.recursive_make_zero(stats)), tape) + (one(loss), fmap(Utils.zero, st_; exclude=isleaf), + fmap(Utils.zero, stats; exclude=isleaf)), tape + ) @set! ts.objective_function = obj_fn @set! ts.states = st_ diff --git a/ext/LuxReverseDiffExt/LuxReverseDiffExt.jl b/ext/LuxReverseDiffExt/LuxReverseDiffExt.jl index bf6f9c6675..d2bd60abbd 100644 --- a/ext/LuxReverseDiffExt/LuxReverseDiffExt.jl +++ b/ext/LuxReverseDiffExt/LuxReverseDiffExt.jl @@ -3,6 +3,7 @@ module LuxReverseDiffExt using ADTypes: ADTypes, AbstractADType, AutoReverseDiff using ArrayInterface: ArrayInterface using FunctionWrappers: FunctionWrapper +using Functors: fmap using ReverseDiff: ReverseDiff, ForwardExecutor, ReverseExecutor, TrackedArray, TrackedReal, @grad_from_chainrules using Setfield: @set! @@ -11,7 +12,7 @@ using Static: False, True using Lux: Lux, Utils using Lux.Training: Training, TrainingBackendCache, TrainState using LuxCore: LuxCore -using MLDataDevices: CPUDevice +using MLDataDevices: CPUDevice, isleaf include("utils.jl") include("rules.jl") diff --git a/ext/LuxReverseDiffExt/training.jl b/ext/LuxReverseDiffExt/training.jl index 6deaf63788..f7b626d267 100644 --- a/ext/LuxReverseDiffExt/training.jl +++ b/ext/LuxReverseDiffExt/training.jl @@ -2,7 +2,8 @@ function Lux.Training.compute_gradients_impl( ad::AutoReverseDiff{false}, obj_fn::F, data, ts::TrainState) where {F} @set! ts.cache = TrainingBackendCache( - ad, True(), Lux.recursive_make_zero(ts.parameters), nothing) + ad, True(), fmap(Utils.zero, ts.parameters; exclude=isleaf), nothing + ) @set! ts.objective_function = obj_fn return Lux.Training.compute_gradients(ad, obj_fn, data, ts) end @@ -11,7 +12,9 @@ function Lux.Training.compute_gradients_impl(::AutoReverseDiff{false}, obj_fn::F ts::TrainState{<:TrainingBackendCache{AutoReverseDiff{false}}}) where {F} dparams = Training.dparameters(ts.cache) tape = ReverseDiff.InstructionTape() - ps_tracked = Lux.recursive_map(Utils.Fix3(TrackedArray, tape), ts.parameters, dparams) + ps_tracked = fmap( + Utils.Fix3(TrackedArray, tape), ts.parameters, dparams; exclude=isleaf + ) loss, st, stats = obj_fn(ts.model, ps_tracked, ts.states, data) loss.deriv = true @@ -27,8 +30,9 @@ end function Lux.Training.compute_gradients_impl( ad::AutoReverseDiff{true}, obj_fn::F, data, ts::TrainState) where {F} @set! ts.cache = TrainingBackendCache( - ad, True(), Lux.recursive_make_zero(ts.parameters), - (; data_cache=deepcopy(data), ps_cache=deepcopy(ts.parameters))) + ad, True(), fmap(Utils.zero, ts.parameters; exclude=isleaf), + (; data_cache=deepcopy(data), ps_cache=deepcopy(ts.parameters)) + ) @set! ts.objective_function = nothing return Lux.Training.compute_gradients(ad, obj_fn, data, ts) @@ -60,12 +64,14 @@ function Lux.Training.compute_gradients_impl(ad::AutoReverseDiff{true}, obj_fn:: (; ps_cache, data_cache) = ts.cache.extras if !first_try - Lux.recursive_copyto!(ps_cache, ts.parameters) - Lux.recursive_copyto!(data_cache, data) + fmap(copyto!, ps_cache, ts.parameters; exclude=isleaf) + fmap(copyto!, data_cache, data; exclude=isleaf) end tape = ReverseDiff.InstructionTape() - ps_tracked = Lux.recursive_map(Utils.Fix3(TrackedArray, tape), ps_cache, dparams) + ps_tracked = fmap( + Utils.Fix3(TrackedArray, tape), ps_cache, dparams; exclude=isleaf + ) loss = first(obj_fn(ts.model, ps_tracked, ts.states, data_cache)) loss.deriv = true @@ -86,9 +92,9 @@ function Lux.Training.compute_gradients_impl(::AutoReverseDiff{true}, obj_fn::F, ts::TrainState{<:TrainingBackendCache{AutoReverseDiff{true}}, F}) where {F} (; ps_cache, data_cache, output) = ts.cache.extras - dparams = Lux.recursive_make_zero!!(ts.cache.dparameters) - Lux.recursive_copyto!(ps_cache, ts.parameters) - Lux.recursive_copyto!(data_cache, data) + dparams = Training.dparameters(ts.cache) + fmap(copyto!, ps_cache, ts.parameters; exclude=isleaf) + fmap(copyto!, data_cache, data; exclude=isleaf) for wrapper in ts.cache.extras.forward_executor wrapper() diff --git a/ext/LuxTrackerExt/LuxTrackerExt.jl b/ext/LuxTrackerExt/LuxTrackerExt.jl index e243611d29..82765ac28d 100644 --- a/ext/LuxTrackerExt/LuxTrackerExt.jl +++ b/ext/LuxTrackerExt/LuxTrackerExt.jl @@ -3,12 +3,14 @@ module LuxTrackerExt using ADTypes: AbstractADType, AutoTracker using ArrayInterface: ArrayInterface using ChainRulesCore: ChainRulesCore +using Functors: fmap using Setfield: @set! using Static: False, True using Tracker: Tracker, TrackedArray, TrackedReal, @grad_from_chainrules using Lux: Lux, Utils using Lux.Training: Training, TrainingBackendCache, TrainState +using MLDataDevices: isleaf const CRC = ChainRulesCore diff --git a/ext/LuxTrackerExt/training.jl b/ext/LuxTrackerExt/training.jl index 982d708f94..c0da2a836b 100644 --- a/ext/LuxTrackerExt/training.jl +++ b/ext/LuxTrackerExt/training.jl @@ -15,7 +15,7 @@ end function Lux.Training.compute_gradients_impl( ad::AutoTracker, obj_fn::F, data, ts::TrainState) where {F} - grads = Lux.recursive_make_zero(ts.parameters) + grads = fmap(Utils.zero, ts.parameters; exclude=isleaf) cache = TrainingBackendCache(ad, True(), grads, nothing) @set! ts.cache = cache @set! ts.objective_function = obj_fn diff --git a/ext/LuxTrackerExt/utils.jl b/ext/LuxTrackerExt/utils.jl index 79bead6df0..81b020b03b 100644 --- a/ext/LuxTrackerExt/utils.jl +++ b/ext/LuxTrackerExt/utils.jl @@ -6,8 +6,9 @@ Utils.gate(x::Tracker.TrackedVector, h::Int, n::Int) = x[Utils.gate(h, n)] Utils.gate(x::Tracker.TrackedMatrix, h::Int, n::Int) = x[Utils.gate(h, n), :] function construct_tracked_params(ps, dps) - map_fn = (p, dp) -> Tracker.TrackedArray(Tracker.Call(), p, dp) - return Lux.recursive_map(map_fn, ps, dps) + return fmap(ps, dps; exclude=isleaf) do p, dp + Tracker.TrackedArray(Tracker.Call(), p, dp) + end end Utils.eltype(::Type{<:TrackedReal{T}}) where {T} = T diff --git a/lib/LuxCore/Project.toml b/lib/LuxCore/Project.toml index eec8280b54..6d154cd2d3 100644 --- a/lib/LuxCore/Project.toml +++ b/lib/LuxCore/Project.toml @@ -1,7 +1,7 @@ name = "LuxCore" uuid = "bb33d45b-7691-41d6-9220-0943567d0623" authors = ["Avik Pal and contributors"] -version = "1.1.1" +version = "1.2.0" [deps] Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" @@ -32,14 +32,14 @@ LuxCoreSetfieldExt = "Setfield" [compat] ArrayInterface = "7.9" ChainRulesCore = "1.24" -Compat = "4.15.0" +Compat = "4.16" DispatchDoctor = "0.4.10" EnzymeCore = "0.8.5" -Functors = "0.4.12" -MLDataDevices = "1" +Functors = "0.5" +MLDataDevices = "1.6" Random = "1.10" Reactant = "0.2.4" ReverseDiff = "1.15" Setfield = "1" -Tracker = "0.2.34" +Tracker = "0.2.36" julia = "1.10" diff --git a/lib/LuxCore/ext/LuxCoreFunctorsExt.jl b/lib/LuxCore/ext/LuxCoreFunctorsExt.jl index d97ed31096..4d91582572 100644 --- a/lib/LuxCore/ext/LuxCoreFunctorsExt.jl +++ b/lib/LuxCore/ext/LuxCoreFunctorsExt.jl @@ -12,8 +12,12 @@ function LuxCore.Internal.fmap_with_path_impl(args...; kwargs...) end LuxCore.Internal.fleaves_impl(args...; kwargs...) = Functors.fleaves(args...; kwargs...) -function Functors.functor(::Type{<:LuxCore.AbstractLuxContainerLayer{layers}}, - x) where {layers} +function Functors.functor(::Type{<:LuxCore.AbstractLuxLayer}, x) + return Functors.NoChildren(), Returns(x) +end + +function Functors.functor( + ::Type{<:LuxCore.AbstractLuxContainerLayer{layers}}, x) where {layers} children = NamedTuple{layers}(getproperty.((x,), layers)) layer_reconstructor = let x = x, layers = layers z -> reduce(LuxCore.Internal.setfield, zip(layers, z); init=x) @@ -21,8 +25,7 @@ function Functors.functor(::Type{<:LuxCore.AbstractLuxContainerLayer{layers}}, return children, layer_reconstructor end -function Functors.functor(::Type{<:LuxCore.AbstractLuxWrapperLayer{layer}}, - x) where {layer} +function Functors.functor(::Type{<:LuxCore.AbstractLuxWrapperLayer{layer}}, x) where {layer} children = NamedTuple{(layer,)}((getproperty(x, layer),)) layer_reconstructor = let x = x, layer = layer z -> LuxCore.Internal.setfield(x, layer, getproperty(z, layer)) diff --git a/lib/LuxCore/test/Project.toml b/lib/LuxCore/test/Project.toml index 1088992bad..1d84c918ea 100644 --- a/lib/LuxCore/test/Project.toml +++ b/lib/LuxCore/test/Project.toml @@ -13,8 +13,8 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Aqua = "0.8.7" EnzymeCore = "0.8.5" ExplicitImports = "1.9.0" -Functors = "0.4.12" -MLDataDevices = "1.0.0" -Optimisers = "0.3.3, 0.4" +Functors = "0.5" +MLDataDevices = "1.6" +Optimisers = "0.3.4, 0.4" Random = "1.10" Test = "1.10" diff --git a/lib/LuxCore/test/runtests.jl b/lib/LuxCore/test/runtests.jl index f55dba7997..1127e28958 100644 --- a/lib/LuxCore/test/runtests.jl +++ b/lib/LuxCore/test/runtests.jl @@ -334,8 +334,11 @@ end end @testset "Convenience Checks" begin - models1 = [Chain((; layer_1=Dense(5, 10), layer_2=Dense(10, 5))), - Chain2(Dense(5, 10), Dense(10, 5)), [Dense(5, 10), Dense(10, 5)]] + models1 = [ + Chain((; layer_1=Dense(5, 10), layer_2=Dense(10, 5))), + Chain2(Dense(5, 10), Dense(10, 5)), + [Dense(5, 10), Dense(10, 5)] + ] @test LuxCore.contains_lux_layer(models1) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index ac6832f191..481872491d 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "1.3.7" +version = "1.3.8" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" @@ -62,7 +62,7 @@ ArrayInterface = "7.9" BLISBLAS = "0.1" CUDA = "5.3.2" ChainRulesCore = "1.24" -Compat = "4.15.0" +Compat = "4.16" CpuId = "0.3" DispatchDoctor = "0.4.12" Enzyme = "0.13.13" @@ -73,9 +73,9 @@ Hwloc = "3.2" KernelAbstractions = "0.9.27" LinearAlgebra = "1.10" LoopVectorization = "0.12.171" -LuxCore = "1" +LuxCore = "1.2" MKL = "0.7" -MLDataDevices = "1.5" +MLDataDevices = "1.6" Markdown = "1.10" NNlib = "0.9.24" Octavian = "0.3.28" @@ -88,6 +88,6 @@ SLEEFPirates = "0.6.43" Static = "0.8.4, 1" StaticArraysCore = "1.4.3" Statistics = "1.10" -Tracker = "0.2.34" +Tracker = "0.2.36" cuDNN = "1.3" julia = "1.10" diff --git a/lib/LuxLib/test/Project.toml b/lib/LuxLib/test/Project.toml index 1277656e6f..2c8ff6aebd 100644 --- a/lib/LuxLib/test/Project.toml +++ b/lib/LuxLib/test/Project.toml @@ -37,7 +37,7 @@ Aqua = "0.8.7" BLISBLAS = "0.1" BenchmarkTools = "1.5" ChainRulesCore = "1.24" -ComponentArrays = "0.15.16" +ComponentArrays = "0.15.18" Enzyme = "0.13.13" EnzymeCore = "0.8.5" ExplicitImports = "1.9.0" @@ -48,7 +48,7 @@ JLArrays = "0.1.5" LoopVectorization = "0.12.171" LuxTestUtils = "1.5" MKL = "0.7" -MLDataDevices = "1.0.0" +MLDataDevices = "1.6" NNlib = "0.9.21" Octavian = "0.3.28" Pkg = "1.10" @@ -61,5 +61,5 @@ Static = "0.8.4, 1" StaticArrays = "1.9.7" Statistics = "1.10" Test = "1.10" -Tracker = "0.2.34" +Tracker = "0.2.36" Zygote = "0.6.70" diff --git a/lib/LuxTestUtils/Project.toml b/lib/LuxTestUtils/Project.toml index aa1792deaa..3097c450c6 100644 --- a/lib/LuxTestUtils/Project.toml +++ b/lib/LuxTestUtils/Project.toml @@ -1,7 +1,7 @@ name = "LuxTestUtils" uuid = "ac9de150-d08f-4546-94fb-7472b5760531" authors = ["Avik Pal "] -version = "1.5.0" +version = "1.6.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -21,19 +21,19 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -ADTypes = "1.8.1" +ADTypes = "1.10" ArrayInterface = "7.9" ChainRulesCore = "1.24.0" -ComponentArrays = "0.15.14" +ComponentArrays = "0.15.18" DispatchDoctor = "0.4.12" Enzyme = "0.13.13" FiniteDiff = "2.23.1" ForwardDiff = "0.10.36" -Functors = "0.4.11" +Functors = "0.5" JET = "0.9.6" -MLDataDevices = "1.0.0" +MLDataDevices = "1.6" ReverseDiff = "1.15.3" Test = "1.10" -Tracker = "0.2.34" +Tracker = "0.2.36" Zygote = "0.6.70" julia = "1.10" diff --git a/lib/LuxTestUtils/test/Project.toml b/lib/LuxTestUtils/test/Project.toml index 3701de4ff2..9fdc23215e 100644 --- a/lib/LuxTestUtils/test/Project.toml +++ b/lib/LuxTestUtils/test/Project.toml @@ -9,7 +9,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] CUDA = "5" -ComponentArrays = "0.15" +ComponentArrays = "0.15.18" Hwloc = "3" InteractiveUtils = "<0.0.1, 1" MetaTesting = "0.1" diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index f743b498f9..47d2266c54 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -1,7 +1,7 @@ name = "MLDataDevices" uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" authors = ["Avik Pal and contributors"] -version = "1.5.3" +version = "1.6.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -54,9 +54,9 @@ Adapt = "4.1" CUDA = "5.2" ChainRules = "1.51" ChainRulesCore = "1.23" -Compat = "4.15" +Compat = "4.16" FillArrays = "1" -Functors = "0.4.8" +Functors = "0.5" GPUArrays = "10, 11" MLUtils = "0.4.4" Metal = "1" @@ -67,7 +67,7 @@ Reactant = "0.2.4" RecursiveArrayTools = "3.8" ReverseDiff = "1.15" SparseArrays = "1.10" -Tracker = "0.2.34" +Tracker = "0.2.36" Zygote = "0.6.69" cuDNN = "1.3" julia = "1.10" diff --git a/lib/MLDataDevices/ext/MLDataDevicesRecursiveArrayToolsExt.jl b/lib/MLDataDevices/ext/MLDataDevicesRecursiveArrayToolsExt.jl index f0b29a2d0c..b75874bfc2 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesRecursiveArrayToolsExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesRecursiveArrayToolsExt.jl @@ -4,6 +4,9 @@ using Adapt: Adapt, adapt using MLDataDevices: MLDataDevices, Internal, AbstractDevice using RecursiveArrayTools: VectorOfArray, DiffEqArray +MLDataDevices.isleaf(::VectorOfArray) = true +MLDataDevices.isleaf(::DiffEqArray) = true + # We want to preserve the structure function Adapt.adapt_structure(to::AbstractDevice, x::VectorOfArray) return VectorOfArray(map(Base.Fix1(adapt, to), x.u)) diff --git a/lib/MLDataDevices/src/internal.jl b/lib/MLDataDevices/src/internal.jl index 45e0a066ce..8b9c770f75 100644 --- a/lib/MLDataDevices/src/internal.jl +++ b/lib/MLDataDevices/src/internal.jl @@ -130,15 +130,21 @@ end for op in (:get_device, :get_device_type) cpu_ret_val = op == :get_device ? CPUDevice() : CPUDevice unknown_ret_val = op == :get_device ? UnknownDevice() : UnknownDevice - not_assigned_msg = "AbstractArray has some undefined references. Giving up, returning \ - $(unknown_ret_val)..." + all_not_assigned_msg = "AbstractArray has all undefined references. Giving up, \ + returning $(unknown_ret_val)..." + some_not_assigned_msg = "AbstractArray has some undefined references. Skipping over \ + unassigned indices..." @eval begin function $(op)(x::AbstractArray{T}) where {T} if recursive_array_eltype(T) - if any(!isassigned(x, i) for i in eachindex(x)) - @warn $(not_assigned_msg) + is_assigned_idxs = findall(Base.Fix1(isassigned, x), eachindex(x)) + if length(is_assigned_idxs) == 0 + @warn $(all_not_assigned_msg) return $(unknown_ret_val) + elseif 0 < length(is_assigned_idxs) < length(x) + @warn $(some_not_assigned_msg) + x = x[is_assigned_idxs] end return mapreduce(MLDataDevices.$(op), combine_devices, x) end diff --git a/lib/MLDataDevices/src/public.jl b/lib/MLDataDevices/src/public.jl index d8c193ae36..13dcaba295 100644 --- a/lib/MLDataDevices/src/public.jl +++ b/lib/MLDataDevices/src/public.jl @@ -275,7 +275,7 @@ $(GET_DEVICE_ADMONITIONS) - `nothing` -- denotes that the object is device agnostic. For example, scalar, abstract range, etc. - - `UnknownDevice()` -- denotes that the device type is unknown + - `UnknownDevice()` -- denotes that the device type is unknown. See also [`get_device_type`](@ref) for a faster alternative that can be used for dispatch based on device type. @@ -295,7 +295,7 @@ $(GET_DEVICE_ADMONITIONS) - `Nothing` -- denotes that the object is device agnostic. For example, scalar, abstract range, etc. - - `UnknownDevice` -- denotes that the device type is unknown + - `UnknownDevice` -- denotes that the device type is unknown. """ function get_device_type end @@ -363,30 +363,16 @@ function set_device!(::Type{T}, ::Nothing, rank::Integer) where {T <: AbstractDe end # Dispatches for Different Data Structures -# Abstract Array / Tuples / NamedTuples have special fast paths to facilitate type stability -# For all other types we rely on fmap which means we lose type stability. -# For Lux, typically models only has these 3 datastructures so we should be mostly fine. -for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI, :XLA) - ldev = Symbol(dev, :Device) - @eval begin - function (D::$(ldev))(x::AbstractArray{T}) where {T} - if isbitstype(T) || Internal.special_aos(x) || x isa Adapt.WrappedArray - return Adapt.adapt(D, x) - end - return map(D, x) - end - (D::$(ldev))(x::Union{Tuple, NamedTuple}) = map(D, x) - function (D::$(ldev))(x) - isleaf(x) && return Adapt.adapt(D, x) - return Functors.fmap(D, x; exclude=isleaf) - end - end +for dev in (CPUDevice, AMDGPUDevice, CUDADevice, MetalDevice, oneAPIDevice, ReactantDevice) + @eval (D::$(dev))(x) = Functors.fmap(Base.Fix1(Adapt.adapt, D), x; exclude=isleaf) end for op in (:get_device, :get_device_type) @eval function $(op)(x) Internal.fast_structure(x) && return Internal.$(op)(x) - return mapreduce(Internal.$(op), Internal.combine_devices, fleaves(x)) + return mapreduce( + Internal.$(op), Internal.combine_devices, fleaves(x; exclude=isleaf) + ) end end @@ -396,9 +382,7 @@ Adapt.adapt_storage(::CPUDevice, rng::AbstractRNG) = rng for T in (AMDGPUDevice, CUDADevice, MetalDevice, oneAPIDevice, ReactantDevice) @eval begin - function Adapt.adapt_storage(to::$(T), ::Random.TaskLocalRNG) - return default_device_rng(to) - end + Adapt.adapt_storage(to::$(T), ::Random.TaskLocalRNG) = default_device_rng(to) Adapt.adapt_storage(::$(T), rng::AbstractRNG) = rng end end @@ -420,5 +404,5 @@ If `MLDataDevices.isleaf(x::T)` is not defined, then it will fall back to `Funct """ isleaf(x) = Functors.isleaf(x) -isleaf(::AbstractArray{T}) where {T} = isbitstype(T) +isleaf(::AbstractArray{T}) where {T} = isbitstype(T) || T <: Number # BigFloat and such are not bitstype isleaf(::Adapt.WrappedArray) = false diff --git a/lib/MLDataDevices/test/Project.toml b/lib/MLDataDevices/test/Project.toml index 1fb732d376..e1e1f1e10c 100644 --- a/lib/MLDataDevices/test/Project.toml +++ b/lib/MLDataDevices/test/Project.toml @@ -25,12 +25,12 @@ Adapt = "4" Aqua = "0.8.4" ArrayInterface = "7.11" ChainRulesTestUtils = "1.13.0" -ComponentArrays = "0.15.8" +ComponentArrays = "0.15.18" ExplicitImports = "1.9.0" FillArrays = "1" ForwardDiff = "0.10.36" -Functors = "0.4.8" -MLUtils = "0.4" +Functors = "0.5" +MLUtils = "0.4.4" OneHotArrays = "0.2.5" Pkg = "1.10" Random = "1.10" @@ -39,5 +39,5 @@ ReverseDiff = "1.15" SafeTestsets = "0.1" SparseArrays = "1.10" Test = "1.10" -Tracker = "0.2.34" +Tracker = "0.2.36" Zygote = "0.6.69" diff --git a/lib/MLDataDevices/test/misc_tests.jl b/lib/MLDataDevices/test/misc_tests.jl index 55265ada34..65f63c9a91 100644 --- a/lib/MLDataDevices/test/misc_tests.jl +++ b/lib/MLDataDevices/test/misc_tests.jl @@ -52,7 +52,7 @@ end @testset "CRC Tests" begin dev = cpu_device() # Other devices don't work with FiniteDifferences.jl - test_rrule(Adapt.adapt, dev, randn(Float64, 10); check_inferred=true) + test_rrule(Adapt.adapt, dev, randn(Float64, 10); check_inferred=false) gdev = gpu_device() if !(gdev isa MetalDevice) # On intel devices causes problems @@ -215,11 +215,11 @@ end x = rand(4, 4) cdev = cpu_device() - @test only(Zygote.gradient(x -> sum(abs2, cdev(x)), x')) isa Matrix{Float64} + @test get_device(only(Zygote.gradient(x -> sum(abs2, cdev(x)), x'))) isa CPUDevice gdev = gpu_device() - @test only(Zygote.gradient(x -> sum(abs2, gdev(x)), x')) isa Matrix{Float64} + @test get_device(only(Zygote.gradient(x -> sum(abs2, gdev(x)), x'))) isa CPUDevice end @testset "Zygote and ChainRules OneElement #1016" begin @@ -256,3 +256,13 @@ end @test x_rd isa Reactant.ConcreteRArray{Bool, 2} end end + +@testset "Device Movement Behavior: FluxML/Flux.jl#2513" begin + dev = gpu_device() + + x = randn(5) + x2 = (x, x) + cx2 = dev(x2) + + @test cx2[1] === cx2[2] +end diff --git a/src/Lux.jl b/src/Lux.jl index a9d1ac552f..525e331fa6 100644 --- a/src/Lux.jl +++ b/src/Lux.jl @@ -83,6 +83,9 @@ include("transform/simplechains.jl") include("distributed/backend.jl") include("distributed/public_api.jl") +# Deprecations for v2 +include("deprecations.jl") + # Layers export Chain, Parallel, SkipConnection, PairwiseFusion, BranchLayer, Maxout, RepeatedLayer export Bilinear, Dense, Embedding, Scale diff --git a/src/contrib/contrib.jl b/src/contrib/contrib.jl index 3e62563f32..a4b07170b8 100644 --- a/src/contrib/contrib.jl +++ b/src/contrib/contrib.jl @@ -6,7 +6,7 @@ using ChainRulesCore: ChainRulesCore using Compat: @compat using ConcreteStructs: @concrete using FastClosures: @closure -using Functors: Functors, KeyPath, fmap_with_path, functor +using Functors: Functors, KeyPath, fmap_with_path, fmapstructure_with_path, functor using Markdown: @doc_str using Optimisers: Optimisers using Random: AbstractRNG, Random diff --git a/src/contrib/debug.jl b/src/contrib/debug.jl index 7d7b388618..38c13f36fe 100644 --- a/src/contrib/debug.jl +++ b/src/contrib/debug.jl @@ -91,7 +91,8 @@ function check_nan_and_throw(x, str::AbstractString, layer, location::KeyPath) return x end - return fmap_with_path(nan_check, x) + fmapstructure_with_path(nan_check, x) + return end function debug_layer_impl(layer, x, ps, st, location, error_check, _) diff --git a/src/deprecations.jl b/src/deprecations.jl new file mode 100644 index 0000000000..c17940db7e --- /dev/null +++ b/src/deprecations.jl @@ -0,0 +1,12 @@ +# Recursive Operations --> Functors.fmap +@deprecate recursive_add!!(x, y) Functors.fmap( + Utils.add!!, x, y; exclude=MLDataDevices.isleaf +) +@deprecate recursive_make_zero(x) Functors.fmap(Utils.zero, x; exclude=MLDataDevices.isleaf) +@deprecate recursive_make_zero!!(x) Functors.fmap( + Utils.zero!!, x; exclude=MLDataDevices.isleaf +) +@deprecate recursive_copyto!(x, y) Functors.fmap( + copyto!, x, y; exclude=MLDataDevices.isleaf +) +@deprecate recursive_map(f, args...) Functors.fmap(f, args...; exclude=MLDataDevices.isleaf) diff --git a/src/helpers/match_eltype.jl b/src/helpers/match_eltype.jl index 529353ecc3..adbd8cd896 100644 --- a/src/helpers/match_eltype.jl +++ b/src/helpers/match_eltype.jl @@ -50,14 +50,17 @@ else arr -> match_eltype( layer, recursive_eltype((ps, st), Val(true)), Utils.eltype(arr), arr) end - return recursive_map(fn, x) + return fmap(fn, x; exclude=MLDataDevices.isleaf) end function match_eltype(layer, ps, st, x, args...) fn = let layer = layer arr -> match_eltype( layer, recursive_eltype((ps, st), Val(true)), Utils.eltype(arr), arr) end - return (recursive_map(fn, x), recursive_map(fn, args)...) + return ( + fmap(fn, x; exclude=MLDataDevices.isleaf), + fmap(fn, args; exclude=MLDataDevices.isleaf)... + ) end end diff --git a/src/helpers/recursive_ops.jl b/src/helpers/recursive_ops.jl index 48b837dd1c..f7f91d66d6 100644 --- a/src/helpers/recursive_ops.jl +++ b/src/helpers/recursive_ops.jl @@ -1,3 +1,11 @@ +const RECURSIVE_OPS_DEPRECATION_MSG = """ +!!! warning "Deprecation Warning" + + Starting Lux v1.3.0, this function is deprecated in favor of `Functors.fmap`. Functors + v0.5 made significant strides towards improving the performance of `fmap` and hence + this function has been deprecated. Users are encouraged to use `Functors.fmap` instead. +""" + """ recursive_add!!(x, y) @@ -6,8 +14,10 @@ is equivalent to doing `fmap(+, x, y)`, but this implementation uses type stable common cases. Any leaves of `x` that are arrays and allow in-place addition will be modified in place. + +$(RECURSIVE_OPS_DEPRECATION_MSG) """ -recursive_add!!(x, y) = recursive_map(Utils.add!!, x, y) +function recursive_add!! end """ recursive_eltype(x, unwrap_ad_types = Val(false)) @@ -32,13 +42,11 @@ recursive_eltype(x::Number, ::True) = Utils.eltype(x) recursive_eltype(::Union{Nothing, Missing, Val}, ::StaticBool) = Bool function recursive_eltype(x::Union{Tuple, NamedTuple}, val::StaticBool) leaves = x isa Tuple ? x : values(x) - length(leaves) == 0 && return Bool - return mapreduce(Base.Fix2(recursive_eltype, val), promote_type, leaves) + return mapreduce(Base.Fix2(recursive_eltype, val), promote_type, leaves; init=Bool) end function recursive_eltype(x, val::StaticBool) - leaves = x isa Tuple ? x : (x isa NamedTuple ? values(x) : Functors.fleaves(x)) - length(leaves) == 0 && return Bool - return mapreduce(Base.Fix2(recursive_eltype, val), promote_type, leaves) + leaves = Functors.fleaves(x; exclude=MLDataDevices.isleaf) + return mapreduce(Base.Fix2(recursive_eltype, val), promote_type, leaves; init=Bool) end """ @@ -48,8 +56,10 @@ Recursively create a zero value for a nested structure `x`. This is equivalent t `fmap(zero, x)`, but this implementation uses type stable code for common cases. See also [`Lux.recursive_make_zero!!`](@ref). + +$(RECURSIVE_OPS_DEPRECATION_MSG) """ -recursive_make_zero(x) = recursive_map(Utils.zero, x) +function recursive_make_zero end """ recursive_make_zero!!(x) @@ -58,8 +68,10 @@ Recursively create a zero value for a nested structure `x`. Leaves that can be m in-place zeroing will be modified in place. See also [`Lux.recursive_make_zero`](@ref) for fully out-of-place version. + +$(RECURSIVE_OPS_DEPRECATION_MSG) """ -recursive_make_zero!!(x) = recursive_map(Utils.zero!!, x) +function recursive_make_zero!! end """ recursive_copyto!(x, y) @@ -67,8 +79,10 @@ recursive_make_zero!!(x) = recursive_map(Utils.zero!!, x) Recursively copy the leaves of two nested structures `x` and `y`. In Functor language, this is equivalent to doing `fmap(copyto!, x, y)`, but this implementation uses type stable code for common cases. Note that any immutable leaf will lead to an error. + +$(RECURSIVE_OPS_DEPRECATION_MSG) """ -recursive_copyto!(x, y) = recursive_map(copyto!, x, y) +function recursive_copyto! end """ recursive_map(f, x, args...) @@ -77,6 +91,8 @@ Similar to `fmap(f, args...)` but with restricted support for the notion of "lea However, this allows for more efficient and type stable implementations of recursive operations. +$(RECURSIVE_OPS_DEPRECATION_MSG) + ## How this works? For the following types it directly defines recursion rules: @@ -95,21 +111,6 @@ For the following types it directly defines recursion rules: """ function recursive_map end -for direct_call in (Number, Val, Nothing) - @eval recursive_map(f::F, x::$(direct_call), args...) where {F} = f(x, args...) -end -function recursive_map(f::F, x::AbstractArray{T}, args...) where {F, T} - (T <: Number || isbitstype(T)) && return f(x, args...) # Not all Number types (BigFloat) are bitstype - return f.(x, args...) -end -function recursive_map(f::F, x::Union{NamedTuple, Tuple}, args...) where {F} - map_fn = let f = f - (args_...) -> recursive_map(f, args_...) - end - return map(map_fn, x, args...) -end -recursive_map(f::F, x, args...) where {F} = fmap(f, x, args...) - @compat(public, (recursive_add!!, recursive_copyto!, recursive_eltype, recursive_make_zero, recursive_map, recursive_make_zero!!)) diff --git a/src/helpers/size_propagator.jl b/src/helpers/size_propagator.jl index f6e37b4c1b..fc0d12b78a 100644 --- a/src/helpers/size_propagator.jl +++ b/src/helpers/size_propagator.jl @@ -13,7 +13,6 @@ using StaticArraysCore: StaticArraysCore const VecT = Union{Bool, Float16, Float32, Float64, Int16, Int32, Int64, Int8, UInt16, UInt32, UInt64, UInt8, SIMDTypes.Bit} -using ..Lux: recursive_map using LuxLib: LuxLib using NNlib: NNlib @@ -130,12 +129,6 @@ end Base.fill!(dest::NilArray, _) = dest -recursively_nillify_internal(x) = x -recursively_nillify_internal(x::AbstractArray) = NilArray(x) - -recursively_nillify(x::AbstractArray{<:Real}) = recursively_nillify_internal(x) -recursively_nillify(x) = recursive_map(recursively_nillify_internal, x) - const Optional{T} = Union{Nothing, T} const Numeric = Union{<:Number, <:AbstractArray{<:Number}} @@ -187,9 +180,10 @@ end function LuxCore.outputsize(layer::AbstractLuxLayer, x, rng::AbstractRNG) ps, st = setup(rng, layer) - x_nil = NilSizePropagation.recursively_nillify(x) - ps_nil = NilSizePropagation.recursively_nillify(ps) - st_nil = NilSizePropagation.recursively_nillify(st) + fn = xᵢ -> xᵢ isa AbstractArray ? NilSizePropagation.NilArray(xᵢ) : xᵢ + x_nil = Functors.fmap(fn, x) + ps_nil = Functors.fmap(fn, ps) + st_nil = Functors.fmap(fn, st) y = first(apply(layer, x_nil, ps_nil, st_nil)) return Utils.unbatched_structure(y) end diff --git a/src/helpers/training.jl b/src/helpers/training.jl index 79ec6eaf24..da2b597a94 100644 --- a/src/helpers/training.jl +++ b/src/helpers/training.jl @@ -4,13 +4,14 @@ using ADTypes: AbstractADType, AutoEnzyme, AutoReverseDiff, AutoTracker, AutoZyg using Compat: @compat using ConcreteStructs: @concrete using FastClosures: @closure +using Functors: fmap using Optimisers: Optimisers using Setfield: @set! using Static: StaticBool, Static, False, True -using ..Lux: Lux +using ..Lux: Lux, Utils using LuxCore: LuxCore, AbstractLuxLayer -using MLDataDevices: ReactantDevice, get_device_type, get_device, cpu_device +using MLDataDevices: MLDataDevices, ReactantDevice, get_device_type, get_device, cpu_device """ TrainState @@ -81,7 +82,7 @@ end dparameters(cache::TrainingBackendCache) = dparameters(cache, cache.first_try) function dparameters(cache::TrainingBackendCache, ::False) - return Lux.recursive_make_zero!!(cache.dparameters) + return fmap(Utils.zero!!, cache.dparameters; exclude=MLDataDevices.isleaf) end dparameters(cache::TrainingBackendCache, ::True) = cache.dparameters diff --git a/test/Project.toml b/test/Project.toml index 1a2b73f0c3..ae07b7777a 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -41,31 +41,31 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -ADTypes = "1.8.1" +ADTypes = "1.10" Adapt = "4" Aqua = "0.8.4" ChainRulesCore = "1.24" -ComponentArrays = "0.15.16" +ComponentArrays = "0.15.18" DispatchDoctor = "0.4.12" Documenter = "1.4" Enzyme = "0.13.13" ExplicitImports = "1.9.0" ForwardDiff = "0.10.36" -Functors = "0.4.12" +Functors = "0.5" Hwloc = "3.2.0" InteractiveUtils = "<0.0.1, 1" LinearAlgebra = "1.10" Logging = "1.10" LoopVectorization = "0.12.171" -LuxCore = "1.0" +LuxCore = "1.2" LuxLib = "1.3.4" LuxTestUtils = "1.5" -MLDataDevices = "1.3" +MLDataDevices = "1.6" MLUtils = "0.4.3" NNlib = "0.9.24" Octavian = "0.3.28" OneHotArrays = "0.2.5" -Optimisers = "0.3.3, 0.4" +Optimisers = "0.3.4, 0.4" Pkg = "1.10" Preferences = "1.4.3" Random = "1.10" @@ -79,5 +79,5 @@ Static = "1" StaticArrays = "1.9" Statistics = "1.11.1" Test = "1.10" -Tracker = "0.2.34" +Tracker = "0.2.36" Zygote = "0.6.70" diff --git a/test/runtests.jl b/test/runtests.jl index 11ba7a088c..4ae064c700 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -26,7 +26,8 @@ if ("all" in LUX_TEST_GROUP || "misc" in LUX_TEST_GROUP) push!(EXTRA_PKGS, Pkg.PackageSpec("MPI")) (BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda") && push!(EXTRA_PKGS, Pkg.PackageSpec("NCCL")) - push!(EXTRA_PKGS, Pkg.PackageSpec("Flux")) + # XXX: Reactivate once Flux is compatible with Functors 0.5 + # push!(EXTRA_PKGS, Pkg.PackageSpec("Flux")) end if !Sys.iswindows() diff --git a/test/transform/flux_tests.jl b/test/transform/flux_tests.jl index 6b86765221..38bc1c1764 100644 --- a/test/transform/flux_tests.jl +++ b/test/transform/flux_tests.jl @@ -1,4 +1,4 @@ -@testitem "FromFluxAdaptor" setup=[SharedTestSetup] tags=[:misc] begin +@testitem "FromFluxAdaptor" setup=[SharedTestSetup] tags=[:misc] skip=:(true) begin import Flux toluxpsst = FromFluxAdaptor(; preserve_ps_st=true) diff --git a/test/utils_tests.jl b/test/utils_tests.jl index 26c3663bb6..a9033d0a3a 100644 --- a/test/utils_tests.jl +++ b/test/utils_tests.jl @@ -187,7 +187,7 @@ end @test Lux.Utils.merge(abc, abc) == (a=1, b=2) end -@testitem "Recursive Utils" tags=[:misc] begin +@testitem "Recursive Utils (Deprecated)" tags=[:misc] begin using Functors, Tracker, ReverseDiff, ForwardDiff struct functorABC{A, B}