Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fixes ImageNet, SimpleRNN examples #499

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion examples/Basics/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,14 @@ W * x
# the `cu` function (or the `gpu` function exported by `Lux`), and it supports all of the
# above operations with the same syntax.

using LuxCUDA
using LuxCUDA, LuxAMDGPU

if LuxCUDA.functional()
x_cu = cu(rand(5, 3))
@show x_cu
elseif LuxAMDGPU.functional() # Similarly, for AMDGPU
x_amd = roc(rand(5, 3))
@show x_amd
end

# ## (Im)mutability
Expand Down
2 changes: 1 addition & 1 deletion examples/DDIM/generate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ end

mkpath(output_dir)

if CUDA.functional()
if LuxCUDA.functional() || LuxAMDGPU.functional()
println("GPU is available.")
else
println("GPU is not available.")
Expand Down
1 change: 1 addition & 0 deletions examples/GravitationalWaveForm/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ close(pkg_io) #hide
using Lux, ComponentArrays, LineSearches, LuxAMDGPU, LuxCUDA, OrdinaryDiffEq, Optimization,
OptimizationOptimJL, Random, SciMLSensitivity
using CairoMakie, MakiePublication

CUDA.allowscalar(false)

# ## Define some Utility Functions
Expand Down
1 change: 1 addition & 0 deletions examples/HyperNet/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Pkg.precompile(; io=pkg_io) #hide
close(pkg_io) #hide
using Lux, ComponentArrays, LuxAMDGPU, LuxCUDA, MLDatasets, MLUtils, OneHotArrays,
Optimisers, Random, Setfield, Statistics, Zygote

CUDA.allowscalar(false)

# ## Loading Datasets
Expand Down
2 changes: 2 additions & 0 deletions examples/ImageNet/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Boltz = "4544d5e4-abc5-4dea-817f-29e4c205d9c8"
Configurations = "5218b696-f38b-4ac9-8b61-a12ec717816d"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
FLoops = "cc61a311-1640-44b5-9fba-1b764f453329"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
FluxMPI = "acf642fa-ee0e-513a-b9d5-bcd150f7ec3b"
Formatting = "59287772-0a20-5a39-b81b-1366585eb4c0"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
Expand All @@ -14,6 +15,7 @@ Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc"
OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand Down
5 changes: 4 additions & 1 deletion examples/ImageNet/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

using Augmentor # Image Augmentation
using Boltz # Computer Vision Models
import Flux # Some Boltz models need Flux
import Metalhead # Some Boltz models need Metalhead
using Configurations # Experiment Configurations
using LuxAMDGPU # AMDGPUs <3
using LuxCUDA # NVIDIA GPUs <3
Expand Down Expand Up @@ -62,7 +64,8 @@ function construct(rng::AbstractRNG, cfg::ModelConfig, ecfg::ExperimentConfig)
end

function construct(cfg::OptimizerConfig)
if cfg.name == "adam" + opt = Adam(cfg.learning_rate)
if cfg.name == "adam"
avik-pal marked this conversation as resolved.
Show resolved Hide resolved
opt = Adam(cfg.learning_rate)
elseif cfg.name == "sgd"
if cfg.nesterov
opt = Nesterov(cfg.learning_rate, cfg.momentum)
Expand Down
10 changes: 8 additions & 2 deletions examples/ImageNet/utils.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
CUDA.allowscalar(false)

# unsafe_free OneHotArrays
CUDA.unsafe_free!(x::OneHotArray) = CUDA.unsafe_free!(x.indices)

if LuxCUDA.functional()
# unsafe_free OneHotArrays
CUDA.unsafe_free!(x::OneHotArray) = CUDA.unsafe_free!(x.indices)
elseif LuxAMDGPU.functional()
# unsafe_free OneHotArrays
AMDGPU.unsafe_free!(x::OneHotArray) = AMDGPU.unsafe_free!(x.indices)
end

# Loss Function
logitcrossentropyloss(ŷ, y) = mean(-sum(y .* logsoftmax(ŷ; dims=1); dims=1))
Expand Down
1 change: 1 addition & 0 deletions examples/NeuralODE/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ using Lux, ComponentArrays, SciMLSensitivity, LuxAMDGPU, LuxCUDA, Optimisers,
OrdinaryDiffEq, Random, Statistics, Zygote, OneHotArrays, InteractiveUtils
import MLDatasets: MNIST
import MLUtils: DataLoader, splitobs

CUDA.allowscalar(false)

# ## Loading MNIST
Expand Down
4 changes: 2 additions & 2 deletions examples/SimpleRNN/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ function (s::SpiralClassifier)(
## See that the parameters and states are automatically populated into a field called
## `lstm_cell` We use `eachslice` to get the elements in the sequence without copying,
## and `Iterators.peel` to split out the first element for LSTM initialization.
x_init, x_rest = Iterators.peel(eachslice(x; dims=2))
x_init, x_rest = Iterators.peel(Lux._eachslice(x, Val(2)))
(y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
## Now that we have the hidden state and memory in `carry` we will pass the input and
## `carry` jointly
Expand Down Expand Up @@ -119,7 +119,7 @@ function compute_loss(x, y, model, ps, st)
return binarycrossentropy(y_pred, y), y_pred, st
end

matches(y_pred, y_true) = sum((y_pred .> 0.5) .== y_true)
matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
avik-pal marked this conversation as resolved.
Show resolved Hide resolved
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)

# Finally lets create an optimiser given the model parameters.
Expand Down
Loading