Skip to content

Commit

Permalink
Minor changes
Browse files Browse the repository at this point in the history
  • Loading branch information
AntonOresten committed Apr 1, 2024
1 parent 914d7a6 commit 8a089c5
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 15 deletions.
35 changes: 23 additions & 12 deletions src/device.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,27 +8,38 @@ flux_device(x::CUDA.CuArray) = flux_device(CUDA.device(x))
cuda_device(x::Flux.FluxCUDADevice) = x.deviceID
cuda_device(x) = cuda_device(flux_device(x))

flux_devices() = flux_device.(CUDA.devices())
flux_devices(xs) = flux_device.(xs)

cuda_devices() = collect(CUDA.devices())
cuda_devices(xs) = cuda_device.(xs)

set_device!(x::Flux.FluxCUDADevice) = CUDA.device!(Int(cuda_device(x).handle))
set_device!(x) = set_device!(flux_device(x))
device!(x::Flux.FluxCUDADevice) = CUDA.device!(Int(cuda_device(x).handle))
device!(x) = device!(flux_device(x))

function device!(f, x::Flux.FluxCUDADevice)
device!(x)
f()
end

flux_cuda_devices() = flux_device.(cuda_devices())

"""
withdevices(f, devices=flux_devices())
withdevices(f, devices=flux_cuda_devices())
Run `f` on each device in `devices`, returning a vector of the results.
"""
function withdevices(f, devices::Vector{Flux.FluxCUDADevice}=flux_devices())
CUDA.@sync map(enumerate(devices)) do (i, device)
CUDA.@async begin
set_device!(device)
function withdevices(f, devices=flux_cuda_devices(); async=true)
if async
CUDA.@sync map(enumerate(flux_devices(devices))) do (i, device)
CUDA.@async begin
device!(device)
f((i, device))
end
end .|> fetch
else
map(enumerate(flux_devices(devices))) do (i, device)
device!(device)
f((i, device))
end
end .|> fetch
end

withdevices(f, devices) = withdevices(f, flux_devices(devices))
end
end
2 changes: 1 addition & 1 deletion src/reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ julia> Flux.params.(models)
Params([Float32[0.79638827;;], Float32[0.0]])
```
"""
function allreduce!(op, xs::Vararg)
function allreduce!(op, xs...)
for arrays in zip(collect.(Flux.params.(xs))...)
allreduce!(op, arrays...)
end
Expand Down
5 changes: 3 additions & 2 deletions src/replicate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,15 @@ end
Base.size(replicas::Replicas) = size(replicas.replicas)
Base.getindex(replicas::Replicas, i::Integer) = replicas.replicas[i]
Base.getindex(replicas::Replicas{D}, device::D) where D = replicas[findfirst(==(device), replicas.devices)]
Base.getindex(replicas::Replicas{D}, device) where D = replicas[flux_device(device)]

Base.summary(replicas::Replicas) = "$(length(replicas)) replicas on devices [$(join(map(d -> d.deviceID, replicas.devices), ", "))]"
Base.show(io::IO, replicas::Replicas) = print(io, summary(replicas))
Base.show(io::IO, ::MIME"text/plain", replicas::Replicas) = show(io, replicas)

"""
replicate(original, devices=flux_devices(), f=identity)
replicate(original, devices=flux_cuda_devices(), f=identity)
Replicate `original` across `devices`, optionally applying `f` to each replica (e.g. `deepcopy`).
"""
replicate(original, devices=flux_devices(), f=identity) = Replicas([f(original) for _ in devices], flux_devices(devices))
replicate(original, devices=flux_cuda_devices(), f=identity) = Replicas([f(original) for _ in devices], flux_devices(devices))

0 comments on commit 8a089c5

Please sign in to comment.