Commit

fixes

CarloLucibello committed Oct 12, 2024
1 parent e0440d4 · commit 8cb1d5a
Showing 6 changed files with 7 additions and 6 deletions.
2 changes: 1 addition & 1 deletion ext/FluxAMDGPUExt/FluxAMDGPUExt.jl
@@ -6,7 +6,7 @@ import Flux
 import Flux: FluxCPUAdaptor, FluxAMDGPUAdaptor, _amd, adapt_storage, fmap
 import Flux: DenseConvDims, Conv, ConvTranspose, conv, conv_reshape_bias
 import NNlib
-
+using MLDataDevices: MLDataDevices
 using AMDGPU
 using Adapt
 using Random
1 change: 1 addition & 0 deletions ext/FluxCUDAExt/FluxCUDAExt.jl
@@ -10,6 +10,7 @@ using ChainRulesCore
 using Random
 using Adapt
 import Adapt: adapt_storage
+using MLDataDevices: MLDataDevices


 const USE_CUDA = Ref{Union{Nothing, Bool}}(nothing)
2 changes: 1 addition & 1 deletion ext/FluxCUDAExt/functor.jl
@@ -57,5 +57,5 @@ function _cuda(id::Union{Nothing, Int}, x)
 end

 function Flux._get_device(::Val{:CUDA}, id::Int)
-    return MLDataUtils.gpu_device(id+1, force=true)
+    return MLDataDevices.gpu_device(id+1, force=true)
 end
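For context, a minimal sketch (not part of the commit) of what the corrected call does, assuming MLDataDevices.jl is available and a CUDA GPU is functional: gpu_device takes a 1-based device index, and force=true raises an error instead of silently falling back to the CPU.

    using MLDataDevices  # exports gpu_device

    # Select the first visible GPU (1-based index); error if none is functional.
    dev = gpu_device(1; force=true)

    # Device objects are callable and move arrays onto the selected accelerator.
    x_gpu = dev(rand(Float32, 4, 4))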
2 changes: 1 addition & 1 deletion ext/FluxMetalExt/FluxMetalExt.jl
@@ -4,7 +4,7 @@ import Flux
 import Flux: FluxCPUAdaptor, FluxMetalAdaptor, _metal, _isleaf, adapt_storage, fmap
 import NNlib
 using ChainRulesCore
-
+using MLDataDevices: MLDataDevices
 using Metal
 using Adapt
 using Random
2 changes: 1 addition & 1 deletion ext/FluxMetalExt/functor.jl
@@ -35,6 +35,6 @@ end

 function Flux._get_device(::Val{:Metal}, id::Int)
     @assert id == 0 "Metal backend only supports one device at the moment"
-    return MLDataDevices.gpu_device()
+    return MLDataDevices.gpu_device(force=true)
 end

4 changes: 2 additions & 2 deletions test/train.jl
@@ -155,7 +155,7 @@ for (trainfn!, name) in ((Flux.train!, "Zygote"), (train_enzyme!, "Enzyme"))
       pen2(x::AbstractArray) = sum(abs2, x)/2
       opt = Flux.setup(Adam(0.1), model)

-      @test_broken begin
+      @test begin
         trainfn!(model, data, opt) do m, x, y
           err = Flux.mse(m(x), y)
           l2 = sum(pen2, Flux.params(m))
@@ -166,7 +166,7 @@ for (trainfn!, name) in ((Flux.train!, "Zygote"), (train_enzyme!, "Enzyme"))
         @test diff1 ≈ diff2

         true
-      end
+      end broken = VERSION >= v"1.11"
     end

     # Take 3: using WeightDecay instead. Need the /2 above, to match exactly.
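For reference, a small self-contained sketch of the @test expr broken = cond form used above (available in the Test stdlib since Julia 1.7): when the condition is true, a failing test is recorded as Broken rather than failing the suite; when it is false, the macro behaves like a plain @test.

    using Test

    @testset "broken keyword sketch" begin
        # Condition true and the expression fails: recorded as Broken, the suite still passes.
        @test 1 == 2 broken = true

        # Condition false: behaves like an ordinary @test.
        @test 1 == 1 broken = false
    end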
