rename "AMD" backend to "AMDGPU" (#2328)
* rename AMD to AMDGPU

* more renaming

* deprecation

* better error messages in get_device

* deprecate binding FluxAMDAdaptor
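
A rough sketch of the user-facing change (illustrative only, not taken from the commit itself):

```julia
using Flux

Flux.gpu_backend!("AMDGPU")   # new spelling of the backend preference; was "AMD"

# An existing "AMD" preference keeps working for now, but `gpu(x)` warns:
# ┌ Warning: "AMD" backend is deprecated. Please use "AMDGPU" instead.
```
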
CarloLucibello authored Sep 7, 2023
1 parent f8d98d2 commit e39fa70
Showing 10 changed files with 79 additions and 67 deletions.
2 changes: 1 addition & 1 deletion .buildkite/pipeline.yml
@@ -67,7 +67,7 @@ steps:
rocm: "*"
rocmgpu: "*"
commands: |
printf "[Flux]\ngpu_backend = \"AMD\"" > LocalPreferences.toml
printf "[Flux]\ngpu_backend = \"AMDGPU\"" > LocalPreferences.toml
timeout_in_minutes: 60
env:
JULIA_AMDGPU_CORE_MUST_LOAD: "1"
14 changes: 7 additions & 7 deletions docs/src/gpu.md
@@ -51,17 +51,17 @@ true

## Selecting GPU backend

- Available GPU backends are: `CUDA`, `AMD` and `Metal`.
+ Available GPU backends are: `CUDA`, `AMDGPU` and `Metal`.

Flux relies on [Preferences.jl](https://github.com/JuliaPackaging/Preferences.jl) for selecting default GPU backend to use.

There are two ways you can specify it:

- - From the REPL/code in your project, call `Flux.gpu_backend!("AMD")` and restart (if needed) Julia session for the changes to take effect.
+ - From the REPL/code in your project, call `Flux.gpu_backend!("AMDGPU")` and restart (if needed) Julia session for the changes to take effect.
- In the `LocalPreferences.toml` file in your project directory, specify:
```toml
[Flux]
gpu_backend = "AMD"
gpu_backend = "AMDGPU"
```
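
For example, the first option might look like this (an illustrative sketch, not part of this diff):

```julia-repl
julia> using Flux

julia> Flux.gpu_backend!("AMDGPU")   # stores the preference in LocalPreferences.toml; restart Julia for it to take effect
```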

Current GPU backend can be fetched from `Flux.GPU_BACKEND` variable:
@@ -296,7 +296,7 @@ julia> model.weight # no change; model still lives on CPU
```
Clearly, this means that the same code will work for any GPU backend and the CPU.

- If the preference backend isn't available or isn't functional, then [`Flux.get_device`](@ref) looks for a CUDA, AMD or Metal backend, and returns a corresponding device (if the backend is available and functional). Otherwise, a CPU device is returned. In the below example, the GPU preference is `"CUDA"`:
+ If the preference backend isn't available or isn't functional, then [`Flux.get_device`](@ref) looks for a CUDA, AMDGPU or Metal backend, and returns a corresponding device (if the backend is available and functional). Otherwise, a CPU device is returned. In the below example, the GPU preference is `"CUDA"`:

```julia-repl
julia> using Flux; # preference is CUDA, but CUDA.jl not loaded
@@ -330,7 +330,7 @@ CUDA.DeviceIterator() for 3 devices:
Then, let's select the device with id `0`:

```julia-repl
- julia> device0 = Flux.get_device("CUDA", 0) # the currently supported values for backend are "CUDA" and "AMD"
+ julia> device0 = Flux.get_device("CUDA", 0) # the currently supported values for backend are "CUDA" and "AMDGPU"
(::Flux.FluxCUDADevice) (generic function with 1 method)
```
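
As an aside (not part of this diff), the returned device object is callable and can be used to move a model onto that device:

```julia-repl
julia> dense_model = Dense(2 => 3) |> device0;   # parameters now live on GPU 0
```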
@@ -367,7 +367,7 @@ CuDevice(1): GeForce RTX 2080 Ti
```

- Due to a limitation in `Metal.jl`, currently this kind of data movement across devices is only supported for `CUDA` and `AMD` backends.
+ Due to a limitation in `Metal.jl`, currently this kind of data movement across devices is only supported for `CUDA` and `AMDGPU` backends.

!!! warning "Printing models after moving to a different device"

@@ -380,7 +380,7 @@ Due to a limitation in `Metal.jl`, currently this kind of data movement across d
Flux.AbstractDevice
Flux.FluxCPUDevice
Flux.FluxCUDADevice
- Flux.FluxAMDDevice
+ Flux.FluxAMDGPUDevice
Flux.FluxMetalDevice
Flux.supported_devices
Flux.get_device
16 changes: 8 additions & 8 deletions ext/FluxAMDGPUExt/FluxAMDGPUExt.jl
@@ -3,7 +3,7 @@ module FluxAMDGPUExt
import ChainRulesCore
import ChainRulesCore: NoTangent
import Flux
- import Flux: FluxCPUAdaptor, FluxAMDAdaptor, _amd, adapt_storage, fmap
+ import Flux: FluxCPUAdaptor, FluxAMDGPUAdaptor, _amd, adapt_storage, fmap
import Flux: DenseConvDims, Conv, ConvTranspose, conv, conv_reshape_bias
import NNlib

@@ -17,16 +17,16 @@ const MIOPENFloat = AMDGPU.MIOpen.MIOPENFloat
# Set to boolean on the first call to check_use_amdgpu
const USE_AMDGPU = Ref{Union{Nothing, Bool}}(nothing)

- function (device::Flux.FluxAMDDevice)(x)
+ function (device::Flux.FluxAMDGPUDevice)(x)
if device.deviceID === nothing
- Flux.gpu(Flux.FluxAMDAdaptor(), x)
+ Flux.gpu(Flux.FluxAMDGPUAdaptor(), x)
else
- return Flux.gpu(Flux.FluxAMDAdaptor(AMDGPU.device_id(device.deviceID) - 1), x) # subtracting 1, because device_id returns a positive integer
+ return Flux.gpu(Flux.FluxAMDGPUAdaptor(AMDGPU.device_id(device.deviceID) - 1), x) # subtracting 1, because device_id returns a positive integer
end
end
- Flux._get_device_name(::Flux.FluxAMDDevice) = "AMD"
- Flux._isavailable(::Flux.FluxAMDDevice) = true
- Flux._isfunctional(::Flux.FluxAMDDevice) = AMDGPU.functional()
+ Flux._get_device_name(::Flux.FluxAMDGPUDevice) = "AMDGPU"
+ Flux._isavailable(::Flux.FluxAMDGPUDevice) = true
+ Flux._isfunctional(::Flux.FluxAMDGPUDevice) = AMDGPU.functional()

function check_use_amdgpu()
if !isnothing(USE_AMDGPU[])
@@ -55,7 +55,7 @@ include("conv.jl")

function __init__()
Flux.AMDGPU_LOADED[] = true
Flux.DEVICES[][Flux.GPU_BACKEND_ORDER["AMD"]] = AMDGPU.functional() ? Flux.FluxAMDDevice(AMDGPU.device()) : Flux.FluxAMDDevice(nothing)
Flux.DEVICES[][Flux.GPU_BACKEND_ORDER["AMDGPU"]] = AMDGPU.functional() ? Flux.FluxAMDGPUDevice(AMDGPU.device()) : Flux.FluxAMDGPUDevice(nothing)
end

# TODO
8 changes: 4 additions & 4 deletions ext/FluxAMDGPUExt/batchnorm.jl
@@ -1,10 +1,10 @@
function (b::Flux.BatchNorm)(x::ROCArray{T}) where T <: MIOPENFloat
- b.λ.(_amd_batchnorm(
+ b.λ.(_amdgpu_batchnorm(
x, b.γ, b.β; μ=b.μ, σ²=b.σ², ϵ=b.ϵ,
within_grad=NNlib.within_gradient(x)))
end

- function _amd_batchnorm(x, γ, β; μ, σ², ϵ, within_grad::Bool)
+ function _amdgpu_batchnorm(x, γ, β; μ, σ², ϵ, within_grad::Bool)
if within_grad
return AMDGPU.MIOpen.batchnorm_training(x, γ, β, μ, σ²; ϵ=Float64(ϵ), iteration=0) # TODO iteration
else
@@ -13,9 +13,9 @@ function _amd_batchnorm(x, γ, β; μ, σ², ϵ, within_grad::Bool)
end

function ChainRulesCore.rrule(
- ::typeof(_amd_batchnorm), x, γ, β; μ, σ², ϵ, within_grad::Bool,
+ ::typeof(_amdgpu_batchnorm), x, γ, β; μ, σ², ϵ, within_grad::Bool,
)
- y, μ_saved, ν_saved = _amd_batchnorm(x, γ, β; μ, σ², ϵ, within_grad)
+ y, μ_saved, ν_saved = _amdgpu_batchnorm(x, γ, β; μ, σ², ϵ, within_grad)
function _batchnorm_pullback(Δ)
dx, dγ, dβ = AMDGPU.MIOpen.∇batchnorm(Δ, x, γ, β, μ_saved, ν_saved)
(NoTangent(), dx, dγ, dβ)
30 changes: 15 additions & 15 deletions ext/FluxAMDGPUExt/functor.jl
@@ -1,5 +1,5 @@
# Convert Float64 to Float32, but preserve Float16.
- function adapt_storage(to::FluxAMDAdaptor, x::AbstractArray)
+ function adapt_storage(to::FluxAMDGPUAdaptor, x::AbstractArray)
if to.id === nothing
if (typeof(x) <: AbstractArray{Float16, N} where N)
N = length(size(x))
@@ -37,13 +37,13 @@ function adapt_storage(to::FluxAMDAdaptor, x::AbstractArray)
end
end

- adapt_storage(::FluxAMDAdaptor, x::Zygote.FillArrays.AbstractFill) =
+ adapt_storage(::FluxAMDGPUAdaptor, x::Zygote.FillArrays.AbstractFill) =
ROCArray(collect(x))
- adapt_storage(::FluxAMDAdaptor, x::Zygote.OneElement) = ROCArray(collect(x))
- adapt_storage(::FluxAMDAdaptor, x::Random.TaskLocalRNG) =
+ adapt_storage(::FluxAMDGPUAdaptor, x::Zygote.OneElement) = ROCArray(collect(x))
+ adapt_storage(::FluxAMDGPUAdaptor, x::Random.TaskLocalRNG) =
AMDGPU.rocRAND.default_rng()
- adapt_storage(::FluxAMDAdaptor, x::AMDGPU.rocRAND.RNG) = x
- adapt_storage(::FluxAMDAdaptor, x::AbstractRNG) = error("""
+ adapt_storage(::FluxAMDGPUAdaptor, x::AMDGPU.rocRAND.RNG) = x
+ adapt_storage(::FluxAMDGPUAdaptor, x::AbstractRNG) = error("""
Cannot map RNG of type $(typeof(x)) to AMDGPU.
AMDGPU execution only supports Random.default_rng().""")

@@ -54,7 +54,7 @@ function ChainRulesCore.rrule(
)
adapt_storage(to, x), dx -> (
NoTangent(), NoTangent(),
- adapt_storage(FluxAMDAdaptor(), unthunk(dx)))
+ adapt_storage(FluxAMDGPUAdaptor(), unthunk(dx)))
end

# Since MIOpen supports only cross-correlation as convolution,
@@ -66,25 +66,25 @@ const FLUX_CONV{M} = Union{
Flux.Conv{<:Any, <:Any, <:Any, <:M, <:Any},
Flux.ConvTranspose{<:Any, <:Any, <:Any, <:M, <:Any}}
const CPU_CONV = FLUX_CONV{Array}
- const AMD_CONV = FLUX_CONV{ROCArray}
+ const AMDGPU_CONV = FLUX_CONV{ROCArray}

_conv_basetype(::Conv) = Conv
_conv_basetype(::ConvTranspose) = ConvTranspose

- Flux._isleaf(::AMD_CONV) = true
+ Flux._isleaf(::AMDGPU_CONV) = true

_exclude(x) = Flux._isleaf(x)
_exclude(::CPU_CONV) = true

function _amd(id::Union{Nothing, Int}, x)
check_use_amdgpu()
USE_AMDGPU[] || return x
- fmap(x -> Adapt.adapt(FluxAMDAdaptor(id), x), x; exclude=_exclude)
+ fmap(x -> Adapt.adapt(FluxAMDGPUAdaptor(id), x), x; exclude=_exclude)
end

# CPU -> GPU

- function Adapt.adapt_structure(to::FluxAMDAdaptor, m::CPU_CONV)
+ function Adapt.adapt_structure(to::FluxAMDGPUAdaptor, m::CPU_CONV)
flipped_weight = reverse(m.weight; dims=ntuple(i -> i, ndims(m.weight) - 2))
_conv_basetype(m)(
Adapt.adapt(to, m.σ),
@@ -95,21 +95,21 @@ end

# Don't adapt again.

- Adapt.adapt_structure(to::FluxAMDAdaptor, m::AMD_CONV) = m
+ Adapt.adapt_structure(to::FluxAMDGPUAdaptor, m::AMDGPU_CONV) = m

# GPU -> CPU

- function Adapt.adapt_structure(to::FluxCPUAdaptor, m::AMD_CONV)
+ function Adapt.adapt_structure(to::FluxCPUAdaptor, m::AMDGPU_CONV)
dims = ntuple(i -> i, ndims(m.weight) - 2)
_conv_basetype(m)(
Adapt.adapt(to, m.σ), reverse(Adapt.adapt(to, m.weight); dims),
Adapt.adapt(to, m.bias), m.stride, m.pad, m.dilation, m.groups)
end

- function Flux.get_device(::Val{:AMD}, id::Int) # id should start from 0
+ function Flux.get_device(::Val{:AMDGPU}, id::Int) # id should start from 0
old_id = AMDGPU.device_id(AMDGPU.device()) - 1 # subtracting 1 because ids start from 0
AMDGPU.device!(AMDGPU.devices()[id + 1]) # adding 1 because ids start from 0
- device = Flux.FluxAMDDevice(AMDGPU.device())
+ device = Flux.FluxAMDGPUDevice(AMDGPU.device())
AMDGPU.device!(AMDGPU.devices()[old_id + 1])
return device
end
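
With this extension loaded, selecting a specific AMD GPU by index might look as follows (a sketch assuming at least two visible devices; not part of the diff):

```julia-repl
julia> using Flux, AMDGPU

julia> device1 = Flux.get_device("AMDGPU", 1);   # wraps the second device; ids start at 0

julia> model = Dense(2 => 3) |> device1;         # weights are now ROCArrays on that device
```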
1 change: 1 addition & 0 deletions src/deprecations.jl
@@ -202,6 +202,7 @@ ChainRulesCore.@non_differentiable _greek_ascii_depwarn(::Any...)
# v0.14 deprecations
@deprecate default_rng_value() Random.default_rng()

+ Base.@deprecate_binding FluxAMDAdaptor FluxAMDGPUAdaptor

# v0.15 deprecations

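
`Base.@deprecate_binding` keeps the old name usable as an alias for the new type, emitting a deprecation warning when it is accessed (visible with `--depwarn=yes`). A minimal sketch:

```julia-repl
julia> Flux.FluxAMDAdaptor === Flux.FluxAMDGPUAdaptor   # the old binding now resolves to the new type
true
```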
45 changes: 28 additions & 17 deletions src/functor.jl
@@ -188,7 +188,7 @@ _isleaf(::AbstractRNG) = true
_isleaf(x) = _isbitsarray(x) || Functors.isleaf(x)

# the order below is important
const GPU_BACKENDS = ["CUDA", "AMD", "Metal", "CPU"]
const GPU_BACKENDS = ["CUDA", "AMDGPU", "Metal", "CPU"]
const GPU_BACKEND_ORDER = Dict(collect(zip(GPU_BACKENDS, 1:length(GPU_BACKENDS))))
const GPU_BACKEND = @load_preference("gpu_backend", "CUDA")
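# Note (not part of the diff): with the renamed backend list, the constants above evaluate to
# GPU_BACKENDS      == ["CUDA", "AMDGPU", "Metal", "CPU"]
# GPU_BACKEND_ORDER == Dict("CUDA" => 1, "AMDGPU" => 2, "Metal" => 3, "CPU" => 4)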

@@ -248,7 +248,10 @@ function gpu(x)
@static if GPU_BACKEND == "CUDA"
gpu(FluxCUDAAdaptor(), x)
elseif GPU_BACKEND == "AMD"
gpu(FluxAMDAdaptor(), x)
@warn "\"AMD\" backend is deprecated. Please use \"AMDGPU\" instead." maxlog=1
gpu(FluxAMDGPUAdaptor(), x)
elseif GPU_BACKEND == "AMDGPU"
gpu(FluxAMDGPUAdaptor(), x)
elseif GPU_BACKEND == "Metal"
gpu(FluxMetalAdaptor(), x)
elseif GPU_BACKEND == "CPU"
@@ -355,13 +358,13 @@ function _cuda end

# AMDGPU extension. ########

- Base.@kwdef struct FluxAMDAdaptor
+ Base.@kwdef struct FluxAMDGPUAdaptor
id::Union{Nothing, Int} = nothing
end

const AMDGPU_LOADED = Ref{Bool}(false)

- function gpu(to::FluxAMDAdaptor, x)
+ function gpu(to::FluxAMDGPUAdaptor, x)
if AMDGPU_LOADED[]
return _amd(to.id, x)
else
@@ -457,7 +460,7 @@ end
"""
Flux.AbstractDevice <: Function
- An abstract type representing `device` objects for different GPU backends. The currently supported backends are `"CUDA"`, `"AMD"`, `"Metal"` and `"CPU"`; the `"CPU"` backend is the fallback case when no GPU is available. GPU extensions of Flux define subtypes of this type.
+ An abstract type representing `device` objects for different GPU backends. The currently supported backends are `"CUDA"`, `"AMDGPU"`, `"Metal"` and `"CPU"`; the `"CPU"` backend is the fallback case when no GPU is available. GPU extensions of Flux define subtypes of this type.
"""
abstract type AbstractDevice <: Function end
@@ -505,11 +508,11 @@ Base.@kwdef struct FluxCUDADevice <: AbstractDevice
end

"""
- FluxAMDDevice <: AbstractDevice
+ FluxAMDGPUDevice <: AbstractDevice
- A type representing `device` objects for the `"AMD"` backend for Flux.
+ A type representing `device` objects for the `"AMDGPU"` backend for Flux.
"""
- Base.@kwdef struct FluxAMDDevice <: AbstractDevice
+ Base.@kwdef struct FluxAMDGPUDevice <: AbstractDevice
deviceID
end

@@ -539,7 +542,7 @@ Get all supported backends for Flux, in order of preference.
julia> using Flux;
julia> Flux.supported_devices()
("CUDA", "AMD", "Metal", "CPU")
("CUDA", "AMDGPU", "Metal", "CPU")
```
"""
supported_devices() = GPU_BACKENDS
@@ -551,12 +554,12 @@ Returns a `device` object for the most appropriate backend for the current Julia
First, the function checks whether a backend preference has been set via the [`Flux.gpu_backend!`](@ref) function. If so, an attempt is made to load this backend. If the corresponding trigger package has been loaded and the backend is functional, a `device` corresponding to the given backend is loaded. Otherwise, the backend is chosen automatically. To update the backend preference, use [`Flux.gpu_backend!`](@ref).
- If there is no preference, then for each of the `"CUDA"`, `"AMD"`, `"Metal"` and `"CPU"` backends in the given order, this function checks whether the given backend has been loaded via the corresponding trigger package, and whether the backend is functional. If so, the `device` corresponding to the backend is returned. If no GPU backend is available, a `Flux.FluxCPUDevice` is returned.
+ If there is no preference, then for each of the `"CUDA"`, `"AMDGPU"`, `"Metal"` and `"CPU"` backends in the given order, this function checks whether the given backend has been loaded via the corresponding trigger package, and whether the backend is functional. If so, the `device` corresponding to the backend is returned. If no GPU backend is available, a `Flux.FluxCPUDevice` is returned.
If `verbose` is set to `true`, then the function prints informative log messages.
# Examples
- For the example given below, the backend preference was set to `"AMD"` via the [`gpu_backend!`](@ref) function.
+ For the example given below, the backend preference was set to `"AMDGPU"` via the [`gpu_backend!`](@ref) function.
```julia-repl
julia> using Flux;
@@ -565,8 +568,8 @@ julia> model = Dense(2 => 3)
Dense(2 => 3) # 9 parameters
julia> device = Flux.get_device(; verbose=true) # this will just load the CPU device
- [ Info: Using backend set in preferences: AMD.
- ┌ Warning: Trying to use backend: AMD but it's trigger package is not loaded.
+ [ Info: Using backend set in preferences: AMDGPU.
+ ┌ Warning: Trying to use backend: AMDGPU but it's trigger package is not loaded.
│ Please load the package and call this function again to respect the preferences backend.
└ @ Flux ~/fluxml/Flux.jl/src/functor.jl:638
[ Info: Using backend: CPU.
Expand All @@ -591,8 +594,8 @@ julia> model = Dense(2 => 3)
Dense(2 => 3) # 9 parameters
julia> device = Flux.get_device(; verbose=true)
- [ Info: Using backend set in preferences: AMD.
- ┌ Warning: Trying to use backend: AMD but it's trigger package is not loaded.
+ [ Info: Using backend set in preferences: AMDGPU.
+ ┌ Warning: Trying to use backend: AMDGPU but it's trigger package is not loaded.
│ Please load the package and call this function again to respect the preferences backend.
└ @ Flux ~/fluxml/Flux.jl/src/functor.jl:637
[ Info: Using backend: CUDA.
@@ -653,7 +656,7 @@ end
Flux.get_device(backend::String, idx::Int = 0)::Flux.AbstractDevice
Get a device object for a backend specified by the string `backend` and `idx`. The currently supported values
- of `backend` are `"CUDA"`, `"AMD"` and `"CPU"`. `idx` must be an integer value between `0` and the number of available devices.
+ of `backend` are `"CUDA"`, `"AMDGPU"` and `"CPU"`. `idx` must be an integer value between `0` and the number of available devices.
# Examples
@@ -684,6 +687,10 @@ julia> cpu_device = Flux.get_device("CPU")
```
"""
function get_device(backend::String, idx::Int = 0)
if backend == "AMD"
@warn "\"AMD\" backend is deprecated. Please use \"AMDGPU\" instead." maxlog=1
backend = "AMDGPU"
end
if backend == "CPU"
return FluxCPUDevice()
else
@@ -693,5 +700,9 @@ end

# Fallback
function get_device(::Val{D}, idx) where D
error("Unsupported backend: $(D). Try importing the corresponding package.")
if D (:CUDA, :AMDGPU, :Metal)
error("Unaivailable backend: $(D). Try importing the corresponding package with `using $D`.")
else
error("Unsupported backend: $(D). Supported backends are $(GPU_BACKENDS).")
end
end
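
The reworked fallback above now distinguishes a supported-but-unloaded backend from an unknown one. A sketch of the two error paths in a session with no GPU package loaded (output abridged, not part of the diff):

```julia-repl
julia> Flux.get_device("AMDGPU", 0)   # AMDGPU.jl not imported
ERROR: Unavailable backend: AMDGPU. Try importing the corresponding package with `using AMDGPU`.

julia> Flux.get_device("oneAPI", 0)   # not a supported backend
ERROR: Unsupported backend: oneAPI. Supported backends are ["CUDA", "AMDGPU", "Metal", "CPU"].
```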