Skip to content

Commit

Permalink
add some clarificatioins about GPU case in MNIST example
Browse files Browse the repository at this point in the history
  • Loading branch information
Ubuntu committed May 26, 2024
1 parent 1dfad77 commit 638a19c
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 13 deletions.
116 changes: 105 additions & 11 deletions examples/mnist/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

julia_version = "1.10.3"
manifest_format = "2.0"
project_hash = "cbc08fafc9a5fad1896170ca5b06f3d7e0fcb016"
project_hash = "3049fd46149696b9ac7df5214242bc2535d0a10e"

[[deps.ARFFFiles]]
deps = ["CategoricalArrays", "Dates", "Parsers", "Tables"]
Expand Down Expand Up @@ -127,6 +127,46 @@ git-tree-sha1 = "6c834533dc1fabd820c1db03c839bf97e45a3fab"
uuid = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
version = "0.10.14"

[[deps.CUDA]]
deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CUDA_Driver_jll", "CUDA_Runtime_Discovery", "CUDA_Runtime_jll", "Crayons", "DataFrames", "ExprTools", "GPUArrays", "GPUCompiler", "KernelAbstractions", "LLVM", "LLVMLoopInfo", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "NVTX", "Preferences", "PrettyTables", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "StaticArrays", "Statistics"]
git-tree-sha1 = "b8c28cb78014f7ae81a652ce1524cba7667dea5c"
uuid = "052768ef-5323-5732-b1bb-66c8b64840ba"
version = "5.3.5"

[deps.CUDA.extensions]
ChainRulesCoreExt = "ChainRulesCore"
EnzymeCoreExt = "EnzymeCore"
SpecialFunctionsExt = "SpecialFunctions"

[deps.CUDA.weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"

[[deps.CUDA_Driver_jll]]
deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"]
git-tree-sha1 = "dc172b558adbf17952001e15cf0d6364e6d78c2f"
uuid = "4ee394cb-3365-5eb0-8335-949819d2adfc"
version = "0.8.1+0"

[[deps.CUDA_Runtime_Discovery]]
deps = ["Libdl"]
git-tree-sha1 = "38f830504358e9972d2a0c3e5d51cb865e0733df"
uuid = "1af6417a-86b4-443c-805f-a4643ffb695f"
version = "0.2.4"

[[deps.CUDA_Runtime_jll]]
deps = ["Artifacts", "CUDA_Driver_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"]
git-tree-sha1 = "4ca7d6d92075906c2ce871ea8bba971fff20d00c"
uuid = "76a88914-d11a-5bdc-97e0-2f5a05c973a2"
version = "0.12.1+0"

[[deps.CUDNN_jll]]
deps = ["Artifacts", "CUDA_Runtime_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"]
git-tree-sha1 = "cbf7d75f8c58b147bdf6acea2e5bc96cececa6d4"
uuid = "62b44479-cb7b-5706-934f-f13b2eb2e645"
version = "9.0.0+1"

[[deps.Cairo_jll]]
deps = ["Artifacts", "Bzip2_jll", "CompilerSupportLibraries_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "JLLWrappers", "LZO_jll", "Libdl", "Pixman_jll", "Xorg_libXext_jll", "Xorg_libXrender_jll", "Zlib_jll", "libpng_jll"]
git-tree-sha1 = "a2f1c8c668c8e3cb4cca4e57a8efdb09067bb3fd"
Expand Down Expand Up @@ -437,6 +477,11 @@ git-tree-sha1 = "1c6317308b9dc757616f0b5cb379db10494443a7"
uuid = "2e619515-83b5-522b-bb60-26c02a35a201"
version = "2.6.2+0"

[[deps.ExprTools]]
git-tree-sha1 = "27415f162e6028e81c72b82ef756bf321213b6ec"
uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
version = "0.1.10"

[[deps.FFMPEG]]
deps = ["FFMPEG_jll"]
git-tree-sha1 = "b57e3acbe22f8484b4b5ff66a7499717fe1a9cc8"
Expand Down Expand Up @@ -573,6 +618,12 @@ git-tree-sha1 = "ec632f177c0d990e64d955ccc1b8c04c485a0950"
uuid = "46192b85-c4d5-4398-a991-12ede77f4527"
version = "0.1.6"

[[deps.GPUCompiler]]
deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Scratch", "TimerOutputs", "UUIDs"]
git-tree-sha1 = "1600477fba37c9fc067b9be21f5e8101f24a8865"
uuid = "61eb1bfa-7361-4325-ad38-22787b887f55"
version = "0.26.4"

[[deps.GR]]
deps = ["Artifacts", "Base64", "DelimitedFiles", "Downloads", "GR_jll", "HTTP", "JSON", "Libdl", "LinearAlgebra", "Preferences", "Printf", "Random", "Serialization", "Sockets", "TOML", "Tar", "Test", "p7zip_jll"]
git-tree-sha1 = "ddda044ca260ee324c5fc07edb6d7cf3f0b9c350"
Expand Down Expand Up @@ -775,6 +826,12 @@ git-tree-sha1 = "c84a835e1a09b289ffcd2271bf2a337bbdda6637"
uuid = "aacddb02-875f-59d6-b918-886e6ef4fbf8"
version = "3.0.3+0"

[[deps.JuliaNVTXCallbacks_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "af433a10f3942e882d3c671aacb203e006a5808f"
uuid = "9c1d0b0a-7046-5b2e-a33f-ea22f176ac7e"
version = "0.2.1+0"

[[deps.JuliaVariables]]
deps = ["MLStyle", "NameResolution"]
git-tree-sha1 = "49fb3cb53362ddadb4415e9b73926d6b40709e70"
Expand Down Expand Up @@ -807,9 +864,9 @@ version = "3.0.0+1"

[[deps.LLVM]]
deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Preferences", "Printf", "Requires", "Unicode"]
git-tree-sha1 = "065c36f95709dd4a676dc6839a35d6fa6f192f24"
git-tree-sha1 = "839c82932db86740ae729779e610f07a1640be9a"
uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
version = "7.1.0"
version = "6.6.3"
weakdeps = ["BFloat16s"]

[deps.LLVM.extensions]
Expand All @@ -821,6 +878,11 @@ git-tree-sha1 = "88b916503aac4fb7f701bb625cd84ca5dd1677bc"
uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab"
version = "0.0.29+0"

[[deps.LLVMLoopInfo]]
git-tree-sha1 = "2e5c102cfc41f48ae4740c7eca7743cc7e7b75ea"
uuid = "8b046642-f1f6-4319-8d3c-209ddc03c586"
version = "1.0.0"

[[deps.LLVMOpenMP_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl"]
git-tree-sha1 = "d986ce2d884d49126836ea94ed5bfb0f12679713"
Expand Down Expand Up @@ -1124,13 +1186,11 @@ deps = ["Artifacts", "BSON", "ChainRulesCore", "Flux", "Functors", "JLD2", "Lazy
git-tree-sha1 = "5aac9a2b511afda7bf89df5044a2e0b429f83152"
uuid = "dbeba491-748d-5e0e-a39e-b530a07fa0cc"
version = "0.9.3"
weakdeps = ["CUDA"]

[deps.Metalhead.extensions]
MetalheadCUDAExt = "CUDA"

[deps.Metalhead.weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"

[[deps.MicroCollections]]
deps = ["BangBang", "InitialValues", "Setfield"]
git-tree-sha1 = "629afd7d10dbc6935ec59b32daeb33bc4460a42e"
Expand Down Expand Up @@ -1186,6 +1246,18 @@ git-tree-sha1 = "60a8e272fe0c5079363b28b0953831e2dd7b7e6f"
uuid = "15e1cf62-19b3-5cfa-8e77-841668bca605"
version = "0.4.3"

[[deps.NVTX]]
deps = ["Colors", "JuliaNVTXCallbacks_jll", "Libdl", "NVTX_jll"]
git-tree-sha1 = "53046f0483375e3ed78e49190f1154fa0a4083a1"
uuid = "5da4648a-3479-48b8-97b9-01cb529c0a1f"
version = "0.3.4"

[[deps.NVTX_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "ce3269ed42816bf18d500c9f63418d4b0d9f5a3b"
uuid = "e98f9f5b-d649-5603-91fd-7774390e6439"
version = "3.1.0+2"

[[deps.NaNMath]]
deps = ["OpenLibm_jll"]
git-tree-sha1 = "0877504529a3e5c3343c6f8b4c0381e57e4387e4"
Expand Down Expand Up @@ -1411,9 +1483,9 @@ version = "0.4.2"

[[deps.PrettyTables]]
deps = ["Crayons", "LaTeXStrings", "Markdown", "PrecompileTools", "Printf", "Reexport", "StringManipulation", "Tables"]
git-tree-sha1 = "88b895d13d53b5577fd53379d913b9ab9ac82660"
git-tree-sha1 = "66b20dd35966a748321d3b2537c4584cf40387c7"
uuid = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
version = "2.3.1"
version = "2.3.2"

[[deps.Printf]]
deps = ["Unicode"]
Expand Down Expand Up @@ -1456,6 +1528,18 @@ uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
deps = ["SHA"]
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[[deps.Random123]]
deps = ["Random", "RandomNumbers"]
git-tree-sha1 = "4743b43e5a9c4a2ede372de7061eed81795b12e7"
uuid = "74087812-796a-5b5d-8853-05524746bad3"
version = "1.7.0"

[[deps.RandomNumbers]]
deps = ["Random", "Requires"]
git-tree-sha1 = "043da614cc7e95c703498a491e2c21f58a2b8111"
uuid = "e6cf234a-135c-5ec9-84dd-332b85af5143"
version = "1.5.3"

[[deps.RealDot]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "9f0a1b71baaf7650f4fa8a1d168c7fb6ee41f0c9"
Expand Down Expand Up @@ -1693,13 +1777,11 @@ deps = ["LinearAlgebra", "PackageExtensionCompat"]
git-tree-sha1 = "5b765c4e401693ab08981989f74a36a010aa1d8e"
uuid = "4db3bf67-4bd7-4b4e-b153-31dc3fb37143"
version = "0.2.2"
weakdeps = ["CUDA"]

[deps.StridedViews.extensions]
StridedViewsCUDAExt = "CUDA"

[deps.StridedViews.weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"

[[deps.StringEncodings]]
deps = ["Libiconv_jll"]
git-tree-sha1 = "b765e46ba27ecf6b44faf70df40c57aa3a547dcb"
Expand Down Expand Up @@ -1772,6 +1854,12 @@ version = "0.1.1"
deps = ["InteractiveUtils", "Logging", "Random", "Serialization"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[[deps.TimerOutputs]]
deps = ["ExprTools", "Printf"]
git-tree-sha1 = "5a13ae8a41237cff5ecf34f73eb1b8f42fff6531"
uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
version = "0.5.24"

[[deps.TranscodingStreams]]
git-tree-sha1 = "5d54d076465da49d6746c647022f3b3674e64156"
uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa"
Expand Down Expand Up @@ -2113,6 +2201,12 @@ git-tree-sha1 = "27798139afc0a2afa7b1824c206d5e87ea587a00"
uuid = "700de1a5-db45-46bc-99cf-38207098b444"
version = "0.2.5"

[[deps.cuDNN]]
deps = ["CEnum", "CUDA", "CUDA_Runtime_Discovery", "CUDNN_jll"]
git-tree-sha1 = "1f6a185a8da9bbbc20134b7b935981f70c9b26ad"
uuid = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
version = "1.3.1"

[[deps.eudev_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "gperf_jll"]
git-tree-sha1 = "431b678a28ebb559d224c0b6b6d01afce87c51ba"
Expand Down
2 changes: 2 additions & 0 deletions examples/mnist/Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
IJulia = "7073ff75-c697-5162-941a-fcdaad2a7d2a"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
Expand All @@ -7,3 +8,4 @@ MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845"
MLJIteration = "614be32b-d00c-4edb-bd02-1eb411ab5e55"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
7 changes: 5 additions & 2 deletions examples/mnist/notebook.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ import MLJFlux
import MLUtils
import MLJIteration # for `skip`

# If running on a GPU, you will also need to `import CUDA` and `import cuDNN`.

using Plots
gr(size=(600, 300*(sqrt(5)-1)));

Expand Down Expand Up @@ -85,15 +87,16 @@ end
# is controlled using using the `finaliser` hyperparameter of the
# classifier.

# We now define the MLJ model. If you have a GPU, substitute
# `acceleration=CUDALibs()` below:
# We now define the MLJ model.

ImageClassifier = @load ImageClassifier
clf = ImageClassifier(
builder=MyConvBuilder(3, 16, 32, 32),
batch_size=50,
epochs=10,
rng=123,
# rng=Random.default_rng() # for GPU
# acceleration=CUDALibs(), # for GPU
)

# You can add Flux options `optimiser=...` and `loss=...` here. At
Expand Down

0 comments on commit 638a19c

Please sign in to comment.