diff --git a/examples/mnist/Manifest.toml b/examples/mnist/Manifest.toml index 50376ad9..29c5e94b 100644 --- a/examples/mnist/Manifest.toml +++ b/examples/mnist/Manifest.toml @@ -2,7 +2,7 @@ julia_version = "1.10.3" manifest_format = "2.0" -project_hash = "cbc08fafc9a5fad1896170ca5b06f3d7e0fcb016" +project_hash = "3049fd46149696b9ac7df5214242bc2535d0a10e" [[deps.ARFFFiles]] deps = ["CategoricalArrays", "Dates", "Parsers", "Tables"] @@ -127,6 +127,46 @@ git-tree-sha1 = "6c834533dc1fabd820c1db03c839bf97e45a3fab" uuid = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" version = "0.10.14" +[[deps.CUDA]] +deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CUDA_Driver_jll", "CUDA_Runtime_Discovery", "CUDA_Runtime_jll", "Crayons", "DataFrames", "ExprTools", "GPUArrays", "GPUCompiler", "KernelAbstractions", "LLVM", "LLVMLoopInfo", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "NVTX", "Preferences", "PrettyTables", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "StaticArrays", "Statistics"] +git-tree-sha1 = "b8c28cb78014f7ae81a652ce1524cba7667dea5c" +uuid = "052768ef-5323-5732-b1bb-66c8b64840ba" +version = "5.3.5" + + [deps.CUDA.extensions] + ChainRulesCoreExt = "ChainRulesCore" + EnzymeCoreExt = "EnzymeCore" + SpecialFunctionsExt = "SpecialFunctions" + + [deps.CUDA.weakdeps] + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" + SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" + +[[deps.CUDA_Driver_jll]] +deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"] +git-tree-sha1 = "dc172b558adbf17952001e15cf0d6364e6d78c2f" +uuid = "4ee394cb-3365-5eb0-8335-949819d2adfc" +version = "0.8.1+0" + +[[deps.CUDA_Runtime_Discovery]] +deps = ["Libdl"] +git-tree-sha1 = "38f830504358e9972d2a0c3e5d51cb865e0733df" +uuid = "1af6417a-86b4-443c-805f-a4643ffb695f" +version = "0.2.4" + +[[deps.CUDA_Runtime_jll]] +deps = ["Artifacts", "CUDA_Driver_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] +git-tree-sha1 = "4ca7d6d92075906c2ce871ea8bba971fff20d00c" +uuid = "76a88914-d11a-5bdc-97e0-2f5a05c973a2" +version = "0.12.1+0" + +[[deps.CUDNN_jll]] +deps = ["Artifacts", "CUDA_Runtime_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] +git-tree-sha1 = "cbf7d75f8c58b147bdf6acea2e5bc96cececa6d4" +uuid = "62b44479-cb7b-5706-934f-f13b2eb2e645" +version = "9.0.0+1" + [[deps.Cairo_jll]] deps = ["Artifacts", "Bzip2_jll", "CompilerSupportLibraries_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "JLLWrappers", "LZO_jll", "Libdl", "Pixman_jll", "Xorg_libXext_jll", "Xorg_libXrender_jll", "Zlib_jll", "libpng_jll"] git-tree-sha1 = "a2f1c8c668c8e3cb4cca4e57a8efdb09067bb3fd" @@ -437,6 +477,11 @@ git-tree-sha1 = "1c6317308b9dc757616f0b5cb379db10494443a7" uuid = "2e619515-83b5-522b-bb60-26c02a35a201" version = "2.6.2+0" +[[deps.ExprTools]] +git-tree-sha1 = "27415f162e6028e81c72b82ef756bf321213b6ec" +uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04" +version = "0.1.10" + [[deps.FFMPEG]] deps = ["FFMPEG_jll"] git-tree-sha1 = "b57e3acbe22f8484b4b5ff66a7499717fe1a9cc8" @@ -573,6 +618,12 @@ git-tree-sha1 = "ec632f177c0d990e64d955ccc1b8c04c485a0950" uuid = "46192b85-c4d5-4398-a991-12ede77f4527" version = "0.1.6" +[[deps.GPUCompiler]] +deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Scratch", "TimerOutputs", "UUIDs"] +git-tree-sha1 = "1600477fba37c9fc067b9be21f5e8101f24a8865" +uuid = "61eb1bfa-7361-4325-ad38-22787b887f55" +version = "0.26.4" + [[deps.GR]] deps = ["Artifacts", "Base64", "DelimitedFiles", "Downloads", "GR_jll", "HTTP", "JSON", "Libdl", "LinearAlgebra", "Preferences", "Printf", "Random", "Serialization", "Sockets", "TOML", "Tar", "Test", "p7zip_jll"] git-tree-sha1 = "ddda044ca260ee324c5fc07edb6d7cf3f0b9c350" @@ -775,6 +826,12 @@ git-tree-sha1 = "c84a835e1a09b289ffcd2271bf2a337bbdda6637" uuid = "aacddb02-875f-59d6-b918-886e6ef4fbf8" version = "3.0.3+0" +[[deps.JuliaNVTXCallbacks_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "af433a10f3942e882d3c671aacb203e006a5808f" +uuid = "9c1d0b0a-7046-5b2e-a33f-ea22f176ac7e" +version = "0.2.1+0" + [[deps.JuliaVariables]] deps = ["MLStyle", "NameResolution"] git-tree-sha1 = "49fb3cb53362ddadb4415e9b73926d6b40709e70" @@ -807,9 +864,9 @@ version = "3.0.0+1" [[deps.LLVM]] deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Preferences", "Printf", "Requires", "Unicode"] -git-tree-sha1 = "065c36f95709dd4a676dc6839a35d6fa6f192f24" +git-tree-sha1 = "839c82932db86740ae729779e610f07a1640be9a" uuid = "929cbde3-209d-540e-8aea-75f648917ca0" -version = "7.1.0" +version = "6.6.3" weakdeps = ["BFloat16s"] [deps.LLVM.extensions] @@ -821,6 +878,11 @@ git-tree-sha1 = "88b916503aac4fb7f701bb625cd84ca5dd1677bc" uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab" version = "0.0.29+0" +[[deps.LLVMLoopInfo]] +git-tree-sha1 = "2e5c102cfc41f48ae4740c7eca7743cc7e7b75ea" +uuid = "8b046642-f1f6-4319-8d3c-209ddc03c586" +version = "1.0.0" + [[deps.LLVMOpenMP_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] git-tree-sha1 = "d986ce2d884d49126836ea94ed5bfb0f12679713" @@ -1124,13 +1186,11 @@ deps = ["Artifacts", "BSON", "ChainRulesCore", "Flux", "Functors", "JLD2", "Lazy git-tree-sha1 = "5aac9a2b511afda7bf89df5044a2e0b429f83152" uuid = "dbeba491-748d-5e0e-a39e-b530a07fa0cc" version = "0.9.3" +weakdeps = ["CUDA"] [deps.Metalhead.extensions] MetalheadCUDAExt = "CUDA" - [deps.Metalhead.weakdeps] - CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" - [[deps.MicroCollections]] deps = ["BangBang", "InitialValues", "Setfield"] git-tree-sha1 = "629afd7d10dbc6935ec59b32daeb33bc4460a42e" @@ -1186,6 +1246,18 @@ git-tree-sha1 = "60a8e272fe0c5079363b28b0953831e2dd7b7e6f" uuid = "15e1cf62-19b3-5cfa-8e77-841668bca605" version = "0.4.3" +[[deps.NVTX]] +deps = ["Colors", "JuliaNVTXCallbacks_jll", "Libdl", "NVTX_jll"] +git-tree-sha1 = "53046f0483375e3ed78e49190f1154fa0a4083a1" +uuid = "5da4648a-3479-48b8-97b9-01cb529c0a1f" +version = "0.3.4" + +[[deps.NVTX_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "ce3269ed42816bf18d500c9f63418d4b0d9f5a3b" +uuid = "e98f9f5b-d649-5603-91fd-7774390e6439" +version = "3.1.0+2" + [[deps.NaNMath]] deps = ["OpenLibm_jll"] git-tree-sha1 = "0877504529a3e5c3343c6f8b4c0381e57e4387e4" @@ -1411,9 +1483,9 @@ version = "0.4.2" [[deps.PrettyTables]] deps = ["Crayons", "LaTeXStrings", "Markdown", "PrecompileTools", "Printf", "Reexport", "StringManipulation", "Tables"] -git-tree-sha1 = "88b895d13d53b5577fd53379d913b9ab9ac82660" +git-tree-sha1 = "66b20dd35966a748321d3b2537c4584cf40387c7" uuid = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" -version = "2.3.1" +version = "2.3.2" [[deps.Printf]] deps = ["Unicode"] @@ -1456,6 +1528,18 @@ uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" deps = ["SHA"] uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +[[deps.Random123]] +deps = ["Random", "RandomNumbers"] +git-tree-sha1 = "4743b43e5a9c4a2ede372de7061eed81795b12e7" +uuid = "74087812-796a-5b5d-8853-05524746bad3" +version = "1.7.0" + +[[deps.RandomNumbers]] +deps = ["Random", "Requires"] +git-tree-sha1 = "043da614cc7e95c703498a491e2c21f58a2b8111" +uuid = "e6cf234a-135c-5ec9-84dd-332b85af5143" +version = "1.5.3" + [[deps.RealDot]] deps = ["LinearAlgebra"] git-tree-sha1 = "9f0a1b71baaf7650f4fa8a1d168c7fb6ee41f0c9" @@ -1693,13 +1777,11 @@ deps = ["LinearAlgebra", "PackageExtensionCompat"] git-tree-sha1 = "5b765c4e401693ab08981989f74a36a010aa1d8e" uuid = "4db3bf67-4bd7-4b4e-b153-31dc3fb37143" version = "0.2.2" +weakdeps = ["CUDA"] [deps.StridedViews.extensions] StridedViewsCUDAExt = "CUDA" - [deps.StridedViews.weakdeps] - CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" - [[deps.StringEncodings]] deps = ["Libiconv_jll"] git-tree-sha1 = "b765e46ba27ecf6b44faf70df40c57aa3a547dcb" @@ -1772,6 +1854,12 @@ version = "0.1.1" deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +[[deps.TimerOutputs]] +deps = ["ExprTools", "Printf"] +git-tree-sha1 = "5a13ae8a41237cff5ecf34f73eb1b8f42fff6531" +uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" +version = "0.5.24" + [[deps.TranscodingStreams]] git-tree-sha1 = "5d54d076465da49d6746c647022f3b3674e64156" uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" @@ -2113,6 +2201,12 @@ git-tree-sha1 = "27798139afc0a2afa7b1824c206d5e87ea587a00" uuid = "700de1a5-db45-46bc-99cf-38207098b444" version = "0.2.5" +[[deps.cuDNN]] +deps = ["CEnum", "CUDA", "CUDA_Runtime_Discovery", "CUDNN_jll"] +git-tree-sha1 = "1f6a185a8da9bbbc20134b7b935981f70c9b26ad" +uuid = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" +version = "1.3.1" + [[deps.eudev_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "gperf_jll"] git-tree-sha1 = "431b678a28ebb559d224c0b6b6d01afce87c51ba" diff --git a/examples/mnist/Project.toml b/examples/mnist/Project.toml index 872f30c8..94a789a2 100644 --- a/examples/mnist/Project.toml +++ b/examples/mnist/Project.toml @@ -1,4 +1,5 @@ [deps] +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" IJulia = "7073ff75-c697-5162-941a-fcdaad2a7d2a" MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" @@ -7,3 +8,4 @@ MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845" MLJIteration = "614be32b-d00c-4edb-bd02-1eb411ab5e55" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" +cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" diff --git a/examples/mnist/notebook.jl b/examples/mnist/notebook.jl index a5562eec..b3defbf6 100644 --- a/examples/mnist/notebook.jl +++ b/examples/mnist/notebook.jl @@ -13,6 +13,8 @@ import MLJFlux import MLUtils import MLJIteration # for `skip` +# If running on a GPU, you will also need to `import CUDA` and `import cuDNN`. + using Plots gr(size=(600, 300*(sqrt(5)-1))); @@ -85,8 +87,7 @@ end # is controlled using using the `finaliser` hyperparameter of the # classifier. -# We now define the MLJ model. If you have a GPU, substitute -# `acceleration=CUDALibs()` below: +# We now define the MLJ model. ImageClassifier = @load ImageClassifier clf = ImageClassifier( @@ -94,6 +95,8 @@ clf = ImageClassifier( batch_size=50, epochs=10, rng=123, +# rng=Random.default_rng() # for GPU +# acceleration=CUDALibs(), # for GPU ) # You can add Flux options `optimiser=...` and `loss=...` here. At