From d3c0e0de7930bc862ec355edd614788f1079b040 Mon Sep 17 00:00:00 2001 From: Shravan Goswami Date: Tue, 28 May 2024 14:17:11 +0530 Subject: [PATCH] flux to lux in bayes-nn --- .../03-bayesian-neural-network/Manifest.toml | 410 ++++++++++++------ .../03-bayesian-neural-network/Project.toml | 17 +- .../03-bayesian-neural-network/index.qmd | 142 +++--- 3 files changed, 365 insertions(+), 204 deletions(-) diff --git a/tutorials/03-bayesian-neural-network/Manifest.toml b/tutorials/03-bayesian-neural-network/Manifest.toml index cd6b05c77..f142e2b36 100755 --- a/tutorials/03-bayesian-neural-network/Manifest.toml +++ b/tutorials/03-bayesian-neural-network/Manifest.toml @@ -1,8 +1,8 @@ # This file is machine-generated - editing it directly is not advised -julia_version = "1.10.0" +julia_version = "1.10.3" manifest_format = "2.0" -project_hash = "52b869ceb2f168d1c74f4d0e17b52e09eedd43b7" +project_hash = "f93d3834307217e52e364d918f0d5a4e2a89ecea" [[deps.ADTypes]] git-tree-sha1 = "016833eb52ba2d6bea9fcb50ca295980e728ee24" @@ -243,6 +243,12 @@ git-tree-sha1 = "2dc09997850d68179b69dafb58ae806167a32b1b" uuid = "d1d4a3ce-64b1-5f1a-9ba4-7e7e69966f35" version = "0.1.8" +[[deps.BitTwiddlingConvenienceFunctions]] +deps = ["Static"] +git-tree-sha1 = "0c5f81f47bbbcf4aea7b2959135713459170798b" +uuid = "62783981-4cbd-42fc-bca8-16325de8dc4b" +version = "0.1.5" + [[deps.Bzip2_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] git-tree-sha1 = "9e2a6b69137e6969bab0152632dcb3bc108c8bdd" @@ -254,6 +260,12 @@ git-tree-sha1 = "389ad5c84de1ae7cf0e28e381131c98ea87d54fc" uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" version = "0.5.0" +[[deps.CPUSummary]] +deps = ["CpuId", "IfElse", "PrecompileTools", "Static"] +git-tree-sha1 = "585a387a490f1c4bd88be67eea15b93da5e85db7" +uuid = "2a0fbf3d-bb9c-48f3-b0a9-814d99fd7ab9" +version = "0.2.5" + [[deps.Cairo_jll]] deps = ["Artifacts", "Bzip2_jll", "CompilerSupportLibraries_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "JLLWrappers", "LZO_jll", "Libdl", "Pixman_jll", "Xorg_libXext_jll", "Xorg_libXrender_jll", "Zlib_jll", "libpng_jll"] git-tree-sha1 = "a2f1c8c668c8e3cb4cca4e57a8efdb09067bb3fd" @@ -292,6 +304,18 @@ weakdeps = ["InverseFunctions"] [deps.ChangesOfVariables.extensions] ChangesOfVariablesInverseFunctionsExt = "InverseFunctions" +[[deps.ChunkSplitters]] +deps = ["Compat", "TestItems"] +git-tree-sha1 = "c7962ce1b964bde2867808235d1c521781df191e" +uuid = "ae650224-84b6-46f8-82ea-d812ca08434e" +version = "2.4.2" + +[[deps.CloseOpenIntervals]] +deps = ["Static", "StaticArrayInterface"] +git-tree-sha1 = "70232f82ffaab9dc52585e0dd043b5e0c6b714f1" +uuid = "fb6a15b2-703c-40df-9091-08a04967cfa9" +version = "0.1.12" + [[deps.CodecZlib]] deps = ["TranscodingStreams", "Zlib_jll"] git-tree-sha1 = "59939d8a997469ee05c4b4944560a820f9ba0d73" @@ -355,7 +379,7 @@ weakdeps = ["Dates", "LinearAlgebra"] [[deps.CompilerSupportLibraries_jll]] deps = ["Artifacts", "Libdl"] uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" -version = "1.0.5+1" +version = "1.1.1+0" [[deps.CompositionsBase]] git-tree-sha1 = "802bb88cd69dfd1509f6670416bd4434015693ad" @@ -366,6 +390,11 @@ weakdeps = ["InverseFunctions"] [deps.CompositionsBase.extensions] CompositionsBaseInverseFunctionsExt = "InverseFunctions" +[[deps.ConcreteStructs]] +git-tree-sha1 = "f749037478283d372048690eb3b5f92a79432b34" +uuid = "2569d6c7-a4a2-43d3-a901-331e8e4be471" +version = "0.2.3" + [[deps.ConcurrentUtilities]] deps = ["Serialization", "Sockets"] git-tree-sha1 = "6cbbd4d241d7e6579ab354737f4dd95ca43946e1" @@ -389,17 +418,17 @@ weakdeps = ["IntervalSets", "StaticArrays"] ConstructionBaseIntervalSetsExt = "IntervalSets" ConstructionBaseStaticArraysExt = "StaticArrays" -[[deps.ContextVariablesX]] -deps = ["Compat", "Logging", "UUIDs"] -git-tree-sha1 = "25cc3803f1030ab855e383129dcd3dc294e322cc" -uuid = "6add18c4-b38d-439d-96f6-d6bc489c04c5" -version = "0.1.3" - [[deps.Contour]] git-tree-sha1 = "439e35b0b36e2e5881738abc8857bd92ad6ff9a8" uuid = "d38c429a-6771-53c6-b99e-75d170b6e991" version = "0.6.3" +[[deps.CpuId]] +deps = ["Markdown"] +git-tree-sha1 = "fcbb72b032692610bfbdb15018ac16a36cf2e406" +uuid = "adafc99b-e345-5852-983c-f28acb93d879" +version = "0.3.1" + [[deps.Crayons]] git-tree-sha1 = "249fe38abf76d48563e2f4556bebd215aa317e15" uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" @@ -510,6 +539,7 @@ deps = ["ADTypes", "AbstractMCMC", "AbstractPPL", "BangBang", "Bijectors", "Comp git-tree-sha1 = "839b5a5257047c2fe47946e84a706e37d9cfee27" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" version = "0.24.11" +weakdeps = ["ChainRulesCore", "EnzymeCore", "ForwardDiff", "MCMCChains", "ReverseDiff", "ZygoteRules"] [deps.DynamicPPL.extensions] DynamicPPLChainRulesCoreExt = ["ChainRulesCore"] @@ -519,14 +549,6 @@ version = "0.24.11" DynamicPPLReverseDiffExt = ["ReverseDiff"] DynamicPPLZygoteRulesExt = ["ZygoteRules"] - [deps.DynamicPPL.weakdeps] - ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" - ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" - MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" - ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" - ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" - [[deps.EllipticalSliceSampling]] deps = ["AbstractMCMC", "ArrayInterface", "Distributions", "Random", "Statistics"] git-tree-sha1 = "e611b7fdfbfb5b18d5e98776c30daede41b44542" @@ -538,6 +560,15 @@ git-tree-sha1 = "bdb1942cd4c45e3c678fd11569d5cccd80976237" uuid = "4e289a0a-7415-4d19-859d-a7e5c4648b56" version = "1.0.4" +[[deps.EnzymeCore]] +git-tree-sha1 = "0910982db2490a20f81dc7db5d4bbea236c027b3" +uuid = "f151be2c-9106-41f4-ab19-57ee4f262869" +version = "0.7.3" +weakdeps = ["Adapt"] + + [deps.EnzymeCore.extensions] + AdaptExt = "Adapt" + [[deps.EpollShim_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] git-tree-sha1 = "8e9441ee83492030ace98f9789a654a6d0b1f643" @@ -585,17 +616,16 @@ git-tree-sha1 = "c6033cc3892d0ef5bb9cd29b7f2f0331ea5184ea" uuid = "f5851436-0d7a-5f13-b9de-f02708fd171a" version = "3.3.10+0" -[[deps.FLoops]] -deps = ["BangBang", "Compat", "FLoopsBase", "InitialValues", "JuliaVariables", "MLStyle", "Serialization", "Setfield", "Transducers"] -git-tree-sha1 = "ffb97765602e3cbe59a0589d237bf07f245a8576" -uuid = "cc61a311-1640-44b5-9fba-1b764f453329" -version = "0.2.1" +[[deps.FastBroadcast]] +deps = ["ArrayInterface", "LinearAlgebra", "Polyester", "Static", "StaticArrayInterface", "StrideArraysCore"] +git-tree-sha1 = "a6e756a880fc419c8b41592010aebe6a5ce09136" +uuid = "7034ab61-46d4-4ed7-9d0f-46aef9175898" +version = "0.2.8" -[[deps.FLoopsBase]] -deps = ["ContextVariablesX"] -git-tree-sha1 = "656f7a6859be8673bf1f35da5670246b923964f7" -uuid = "b9860ae5-e623-471e-878b-f6a53c775ea6" -version = "0.1.1" +[[deps.FastClosures]] +git-tree-sha1 = "acebe244d53ee1b461970f8910c235b259e772ef" +uuid = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" +version = "0.3.2" [[deps.FileWatching]] uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" @@ -618,24 +648,6 @@ git-tree-sha1 = "05882d6995ae5c12bb5f36dd2ed3f61c98cbb172" uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" version = "0.8.5" -[[deps.Flux]] -deps = ["Adapt", "ChainRulesCore", "Compat", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "Zygote"] -git-tree-sha1 = "a5475163b611812d073171583982c42ea48d22b0" -uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" -version = "0.14.15" - - [deps.Flux.extensions] - FluxAMDGPUExt = "AMDGPU" - FluxCUDAExt = "CUDA" - FluxCUDAcuDNNExt = ["CUDA", "cuDNN"] - FluxMetalExt = "Metal" - - [deps.Flux.weakdeps] - AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" - CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" - Metal = "dde4c033-4e86-420c-a63e-0dd931031962" - cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" - [[deps.Fontconfig_jll]] deps = ["Artifacts", "Bzip2_jll", "Expat_jll", "FreeType2_jll", "JLLWrappers", "Libdl", "Libuuid_jll", "Zlib_jll"] git-tree-sha1 = "db16beca600632c95fc8aca29890d83788dd8b23" @@ -696,12 +708,6 @@ git-tree-sha1 = "ff38ba61beff76b8f4acad8ab0c97ef73bb670cb" uuid = "0656b61e-2033-5cc2-a64a-77c0f6c09b89" version = "3.3.9+0" -[[deps.GPUArrays]] -deps = ["Adapt", "GPUArraysCore", "LLVM", "LinearAlgebra", "Printf", "Random", "Reexport", "Serialization", "Statistics"] -git-tree-sha1 = "68e8ff56a4a355a85d2784b94614491f8c900cde" -uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" -version = "10.1.0" - [[deps.GPUArraysCore]] deps = ["Adapt"] git-tree-sha1 = "ec632f177c0d990e64d955ccc1b8c04c485a0950" @@ -761,11 +767,10 @@ git-tree-sha1 = "f218fe3736ddf977e0e772bc9a586b2383da2685" uuid = "34004b35-14d8-5ef3-9330-4cdb6864b03a" version = "0.3.23" -[[deps.IRTools]] -deps = ["InteractiveUtils", "MacroTools"] -git-tree-sha1 = "950c3717af761bc3ff906c2e8e52bd83390b6ec2" -uuid = "7869d1d1-7146-5819-86e3-90919afe41df" -version = "0.4.14" +[[deps.IfElse]] +git-tree-sha1 = "debdd00ffef04665ccbb3e150747a77560e8fad1" +uuid = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" +version = "0.1.1" [[deps.InitialValues]] git-tree-sha1 = "4da0f88e9a39111c2fa3add390ab15f3a44f3ca3" @@ -863,24 +868,16 @@ git-tree-sha1 = "c84a835e1a09b289ffcd2271bf2a337bbdda6637" uuid = "aacddb02-875f-59d6-b918-886e6ef4fbf8" version = "3.0.3+0" -[[deps.JuliaVariables]] -deps = ["MLStyle", "NameResolution"] -git-tree-sha1 = "49fb3cb53362ddadb4415e9b73926d6b40709e70" -uuid = "b14d175d-62b4-44ba-8fb7-3064adc8c3ec" -version = "0.2.4" - [[deps.KernelAbstractions]] deps = ["Adapt", "Atomix", "InteractiveUtils", "LinearAlgebra", "MacroTools", "PrecompileTools", "Requires", "SparseArrays", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"] git-tree-sha1 = "db02395e4c374030c53dc28f3c1d33dec35f7272" uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c" version = "0.9.19" +weakdeps = ["EnzymeCore"] [deps.KernelAbstractions.extensions] EnzymeExt = "EnzymeCore" - [deps.KernelAbstractions.weakdeps] - EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" - [[deps.KernelDensity]] deps = ["Distributions", "DocStringExtensions", "FFTW", "Interpolations", "StatsBase"] git-tree-sha1 = "7d703202e65efa1369de1279c162b915e245eed1" @@ -901,9 +898,9 @@ version = "3.0.0+1" [[deps.LLVM]] deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Preferences", "Printf", "Requires", "Unicode"] -git-tree-sha1 = "839c82932db86740ae729779e610f07a1640be9a" +git-tree-sha1 = "065c36f95709dd4a676dc6839a35d6fa6f192f24" uuid = "929cbde3-209d-540e-8aea-75f648917ca0" -version = "6.6.3" +version = "7.1.0" [deps.LLVM.extensions] BFloat16sExt = "BFloat16s" @@ -957,6 +954,12 @@ version = "0.16.3" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" SymEngine = "123dc426-2d89-5057-bbad-38513e3affd8" +[[deps.LayoutPointers]] +deps = ["ArrayInterface", "LinearAlgebra", "ManualMemory", "SIMDTypes", "Static", "StaticArrayInterface"] +git-tree-sha1 = "62edfee3211981241b57ff1cedf4d74d79519277" +uuid = "10f19ff3-798f-405d-979b-55457f8fc047" +version = "0.1.15" + [[deps.LazyArtifacts]] deps = ["Artifacts", "Pkg"] uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3" @@ -1105,6 +1108,106 @@ git-tree-sha1 = "c1dd6d7978c12545b4179fb6153b9250c96b0075" uuid = "e6f89c97-d47a-5376-807f-9c37f3926c36" version = "1.0.3" +[[deps.Lux]] +deps = ["ADTypes", "Adapt", "ArgCheck", "ArrayInterface", "ChainRulesCore", "ConcreteStructs", "ConstructionBase", "FastClosures", "Functors", "GPUArraysCore", "LinearAlgebra", "LuxCore", "LuxDeviceUtils", "LuxLib", "MacroTools", "Markdown", "OhMyThreads", "PrecompileTools", "Preferences", "Random", "Reexport", "Setfield", "WeightInitializers"] +git-tree-sha1 = "93c0d182dbcf2dfe1e8f3e68751979f949fca5e6" +uuid = "b2108857-7c20-44ae-9111-449ecde12c47" +version = "0.5.51" + + [deps.Lux.extensions] + LuxComponentArraysExt = "ComponentArrays" + LuxDynamicExpressionsExt = "DynamicExpressions" + LuxDynamicExpressionsForwardDiffExt = ["DynamicExpressions", "ForwardDiff"] + LuxEnzymeExt = "Enzyme" + LuxFluxExt = "Flux" + LuxForwardDiffExt = "ForwardDiff" + LuxLuxAMDGPUExt = "LuxAMDGPU" + LuxMLUtilsExt = "MLUtils" + LuxMPIExt = "MPI" + LuxMPINCCLExt = ["CUDA", "MPI", "NCCL"] + LuxOptimisersExt = "Optimisers" + LuxReverseDiffExt = "ReverseDiff" + LuxSimpleChainsExt = "SimpleChains" + LuxTrackerExt = "Tracker" + LuxZygoteExt = "Zygote" + + [deps.Lux.weakdeps] + CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" + DynamicExpressions = "a40a106e-89c9-4ca8-8020-a735e8728b6b" + Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" + Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" + ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" + LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b" + MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" + MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" + NCCL = "3fe64909-d7a1-4096-9b7d-7a0f12cf0f6b" + Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" + ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" + SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5" + Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" + Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + +[[deps.LuxCore]] +deps = ["Functors", "Random", "Setfield"] +git-tree-sha1 = "c96985555a9fe41d7ec2bd5625d6c2077e05e33e" +uuid = "bb33d45b-7691-41d6-9220-0943567d0623" +version = "0.1.15" + +[[deps.LuxDeviceUtils]] +deps = ["Adapt", "ChainRulesCore", "FastClosures", "Functors", "LuxCore", "PrecompileTools", "Preferences", "Random"] +git-tree-sha1 = "bbcf12d598b8ef6d2b12e506b1d18125552c3b27" +uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553" +version = "0.1.20" + + [deps.LuxDeviceUtils.extensions] + LuxDeviceUtilsAMDGPUExt = "AMDGPU" + LuxDeviceUtilsCUDAExt = "CUDA" + LuxDeviceUtilsFillArraysExt = "FillArrays" + LuxDeviceUtilsGPUArraysExt = "GPUArrays" + LuxDeviceUtilsLuxAMDGPUExt = "LuxAMDGPU" + LuxDeviceUtilsLuxCUDAExt = "LuxCUDA" + LuxDeviceUtilsMetalGPUArraysExt = ["GPUArrays", "Metal"] + LuxDeviceUtilsRecursiveArrayToolsExt = "RecursiveArrayTools" + LuxDeviceUtilsSparseArraysExt = "SparseArrays" + LuxDeviceUtilsZygoteExt = "Zygote" + + [deps.LuxDeviceUtils.weakdeps] + AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" + CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" + GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" + LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b" + LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" + Metal = "dde4c033-4e86-420c-a63e-0dd931031962" + RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" + SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + +[[deps.LuxLib]] +deps = ["ArrayInterface", "ChainRulesCore", "EnzymeCore", "FastBroadcast", "FastClosures", "GPUArraysCore", "LinearAlgebra", "LuxCore", "Markdown", "NNlib", "PrecompileTools", "Random", "Reexport", "Statistics"] +git-tree-sha1 = "02920ad8b5f7c8a24cb32fb29dd990eac944cd71" +uuid = "82251201-b29d-42c6-8e01-566dec8acb11" +version = "0.3.26" + + [deps.LuxLib.extensions] + LuxLibAMDGPUExt = "AMDGPU" + LuxLibCUDAExt = "CUDA" + LuxLibForwardDiffExt = "ForwardDiff" + LuxLibReverseDiffExt = "ReverseDiff" + LuxLibTrackerAMDGPUExt = ["AMDGPU", "Tracker"] + LuxLibTrackerExt = "Tracker" + LuxLibTrackercuDNNExt = ["CUDA", "Tracker", "cuDNN"] + LuxLibcuDNNExt = ["CUDA", "cuDNN"] + + [deps.LuxLib.weakdeps] + AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" + CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" + ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" + Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" + cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" + [[deps.MCMCChains]] deps = ["AbstractMCMC", "AxisArrays", "Dates", "Distributions", "IteratorInterfaceExtensions", "KernelDensity", "LinearAlgebra", "MCMCDiagnosticTools", "MLJModelInterface", "NaturalSort", "OrderedCollections", "PrettyTables", "Random", "RecipesBase", "Statistics", "StatsBase", "StatsFuns", "TableTraits", "Tables"] git-tree-sha1 = "d28056379864318172ff4b7958710cfddd709339" @@ -1129,23 +1232,17 @@ git-tree-sha1 = "d2a45e1b5998ba3fdfb6cfe0c81096d4c7fb40e7" uuid = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" version = "1.9.6" -[[deps.MLStyle]] -git-tree-sha1 = "bc38dff0548128765760c79eb7388a4b37fae2c8" -uuid = "d8e11817-5142-5d16-987a-aa16d5891078" -version = "0.4.17" - -[[deps.MLUtils]] -deps = ["ChainRulesCore", "Compat", "DataAPI", "DelimitedFiles", "FLoops", "NNlib", "Random", "ShowCases", "SimpleTraits", "Statistics", "StatsBase", "Tables", "Transducers"] -git-tree-sha1 = "b45738c2e3d0d402dffa32b2c1654759a2ac35a4" -uuid = "f1d291b0-491e-4a28-83b9-f70985020b54" -version = "0.4.4" - [[deps.MacroTools]] deps = ["Markdown", "Random"] git-tree-sha1 = "2fa9ee3e63fd3a4f7a9a4f4744a52f4856de82df" uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" version = "0.5.13" +[[deps.ManualMemory]] +git-tree-sha1 = "bcaef4fc7a0cfe2cba636d84cda54b5e4e4ca3cd" +uuid = "d125e4d3-2237-4719-b19c-fa641b8a4667" +version = "0.1.8" + [[deps.MappedArrays]] git-tree-sha1 = "2dab0221fe2b0f2cb6754eaa743cc266339f527e" uuid = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900" @@ -1192,9 +1289,9 @@ version = "2023.1.10" [[deps.NNlib]] deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"] -git-tree-sha1 = "e0cea7ec219ada9ac80ec2e82e374ab2f154ae05" +git-tree-sha1 = "3d4617f943afe6410206a5294a95948c8d1b35bd" uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.9.16" +version = "0.9.17" [deps.NNlib.extensions] NNlibAMDGPUExt = "AMDGPU" @@ -1214,17 +1311,11 @@ git-tree-sha1 = "0877504529a3e5c3343c6f8b4c0381e57e4387e4" uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" version = "1.0.2" -[[deps.NameResolution]] -deps = ["PrettyPrint"] -git-tree-sha1 = "1a0fa0e9613f46c9b8c11eee38ebb4f590013c5e" -uuid = "71a1bf82-56d0-4bbc-8a3c-48b961074391" -version = "0.1.5" - [[deps.NamedArrays]] deps = ["Combinatorics", "DataStructures", "DelimitedFiles", "InvertedIndices", "LinearAlgebra", "Random", "Requires", "SparseArrays", "Statistics"] -git-tree-sha1 = "0ae91efac93c3859f5c812a24c9468bb9e50b028" +git-tree-sha1 = "c7aab3836df3f31591a2b4167fcd87b741dacfc9" uuid = "86f7a689-2022-50b4-a561-43c23ac3c673" -version = "0.10.1" +version = "0.10.2" [[deps.NaturalSort]] git-tree-sha1 = "eda490d06b9f7c00752ee81cfa451efe55521e21" @@ -1250,16 +1341,16 @@ git-tree-sha1 = "887579a3eb005446d514ab7aeac5d1d027658b8f" uuid = "e7412a2a-1a6e-54c0-be00-318e2571c051" version = "1.3.5+1" -[[deps.OneHotArrays]] -deps = ["Adapt", "ChainRulesCore", "Compat", "GPUArraysCore", "LinearAlgebra", "NNlib"] -git-tree-sha1 = "963a3f28a2e65bb87a68033ea4a616002406037d" -uuid = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" -version = "0.2.5" +[[deps.OhMyThreads]] +deps = ["BangBang", "ChunkSplitters", "StableTasks", "TaskLocalValues"] +git-tree-sha1 = "4b43015960c9e1b660cfae4c1b19c7ed9c86b92c" +uuid = "67456a42-1dca-4109-a031-0a68de7e3ad5" +version = "0.5.2" [[deps.OpenBLAS_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" -version = "0.3.23+2" +version = "0.3.23+4" [[deps.OpenLibm_jll]] deps = ["Artifacts", "Libdl"] @@ -1318,6 +1409,12 @@ git-tree-sha1 = "8489905bcdbcfac64d1daa51ca07c0d8f0283821" uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" version = "2.8.1" +[[deps.PartialFunctions]] +deps = ["MacroTools"] +git-tree-sha1 = "47b49a4dbc23b76682205c646252c0f9e1eb75af" +uuid = "570af359-4316-4cb7-8c74-252c00c2016b" +version = "1.2.0" + [[deps.Pipe]] git-tree-sha1 = "6842804e7867b115ca9de748a0cf6b364523c16d" uuid = "b98c9c47-44ae-5843-9183-064241ee97a0" @@ -1336,9 +1433,9 @@ version = "1.10.0" [[deps.PlotThemes]] deps = ["PlotUtils", "Statistics"] -git-tree-sha1 = "1f03a2d339f42dca4a4da149c7e15e9b896ad899" +git-tree-sha1 = "6e55c6841ce3411ccb3457ee52fc48cb698d6fb0" uuid = "ccf2f8ad-2431-5c83-bf29-c5338b663b6a" -version = "3.1.0" +version = "3.2.0" [[deps.PlotUtils]] deps = ["ColorSchemes", "Colors", "Dates", "PrecompileTools", "Printf", "Random", "Reexport", "Statistics"] @@ -1366,6 +1463,18 @@ version = "1.40.4" ImageInTerminal = "d8c32880-2388-543b-8c61-d9f865259254" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" +[[deps.Polyester]] +deps = ["ArrayInterface", "BitTwiddlingConvenienceFunctions", "CPUSummary", "IfElse", "ManualMemory", "PolyesterWeave", "Requires", "Static", "StaticArrayInterface", "StrideArraysCore", "ThreadingUtilities"] +git-tree-sha1 = "b3e2bae88cf07baf0a051fe09666b8ef97aefe93" +uuid = "f517fe37-dbe3-4b94-8317-1923a5111588" +version = "0.7.14" + +[[deps.PolyesterWeave]] +deps = ["BitTwiddlingConvenienceFunctions", "CPUSummary", "IfElse", "Static", "ThreadingUtilities"] +git-tree-sha1 = "240d7170f5ffdb285f9427b92333c3463bf65bf6" +uuid = "1d0040c9-8b98-4ee7-8388-3f51789ca0ad" +version = "0.2.1" + [[deps.PrecompileTools]] deps = ["Preferences"] git-tree-sha1 = "5aa36f7049a63a1528fe8f7c3f2113413ffd4e1f" @@ -1378,16 +1487,11 @@ git-tree-sha1 = "9306f6085165d270f7e3db02af26a400d580f5c6" uuid = "21216c6a-2e73-6563-6e65-726566657250" version = "1.4.3" -[[deps.PrettyPrint]] -git-tree-sha1 = "632eb4abab3449ab30c5e1afaa874f0b98b586e4" -uuid = "8162dcfd-2161-5ef2-ae6c-7681170c5f98" -version = "0.2.0" - [[deps.PrettyTables]] deps = ["Crayons", "LaTeXStrings", "Markdown", "PrecompileTools", "Printf", "Reexport", "StringManipulation", "Tables"] -git-tree-sha1 = "88b895d13d53b5577fd53379d913b9ab9ac82660" +git-tree-sha1 = "66b20dd35966a748321d3b2537c4584cf40387c7" uuid = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" -version = "2.3.1" +version = "2.3.2" [[deps.Printf]] deps = ["Unicode"] @@ -1406,9 +1510,9 @@ uuid = "92933f4c-e287-5a05-a399-4b506db050ca" version = "1.10.0" [[deps.PtrArrays]] -git-tree-sha1 = "077664975d750757f30e739c870fbbdc01db7913" +git-tree-sha1 = "f011fbb92c4d401059b2212c05c0601b70f8b759" uuid = "43287f4e-b6f4-7ad1-bb20-aadabca52c3d" -version = "1.1.0" +version = "1.2.0" [[deps.Qt6Base_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "Fontconfig_jll", "Glib_jll", "JLLWrappers", "Libdl", "Libglvnd_jll", "OpenSSL_jll", "Vulkan_Loader_jll", "Xorg_libSM_jll", "Xorg_libXext_jll", "Xorg_libXrender_jll", "Xorg_libxcb_jll", "Xorg_xcb_util_cursor_jll", "Xorg_xcb_util_image_jll", "Xorg_xcb_util_keysyms_jll", "Xorg_xcb_util_renderutil_jll", "Xorg_xcb_util_wm_jll", "Zlib_jll", "libinput_jll", "xkbcommon_jll"] @@ -1562,11 +1666,16 @@ version = "0.5.13" uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" version = "0.7.0" +[[deps.SIMDTypes]] +git-tree-sha1 = "330289636fb8107c5f32088d2741e9fd7a061a5c" +uuid = "94e857df-77ce-4151-89e5-788b33177be4" +version = "0.1.0" + [[deps.SciMLBase]] deps = ["ADTypes", "ArrayInterface", "CommonSolve", "ConstructionBase", "Distributed", "DocStringExtensions", "EnumX", "FunctionWrappersWrappers", "IteratorInterfaceExtensions", "LinearAlgebra", "Logging", "Markdown", "PrecompileTools", "Preferences", "Printf", "RecipesBase", "RecursiveArrayTools", "Reexport", "RuntimeGeneratedFunctions", "SciMLOperators", "SciMLStructures", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface", "Tables"] -git-tree-sha1 = "265f1a7a804d8093fa0b17e33e45373a77e56ca5" +git-tree-sha1 = "9f59654e2a85017ee27b0f59c7fac5a57aa10ced" uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462" -version = "2.38.0" +version = "2.39.0" [deps.SciMLBase.extensions] SciMLBaseChainRulesCoreExt = "ChainRulesCore" @@ -1622,11 +1731,6 @@ version = "1.1.1" deps = ["Distributed", "Mmap", "Random", "Serialization"] uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383" -[[deps.ShowCases]] -git-tree-sha1 = "7f534ad62ab2bd48591bdeac81994ea8c445e4a5" -uuid = "605ecd9f-84a6-4c9e-81e2-4798472b76a3" -version = "0.1.0" - [[deps.Showoff]] deps = ["Dates", "Grisu"] git-tree-sha1 = "91eddf657aca81df9ae6ceb20b959ae5653ad1de" @@ -1638,12 +1742,6 @@ git-tree-sha1 = "874e8867b33a00e784c8a7e4b60afe9e037b74e1" uuid = "777ac1f9-54b0-4bf8-805c-2214025038e7" version = "1.1.0" -[[deps.SimpleTraits]] -deps = ["InteractiveUtils", "MacroTools"] -git-tree-sha1 = "5d7e3f4e11935503d3ecaf7186eac40602e7d231" -uuid = "699a6c99-e7fa-54fc-8d76-47d257e15c1d" -version = "0.9.4" - [[deps.SimpleUnPack]] git-tree-sha1 = "58e6353e72cde29b90a69527e56df1b5c3d8c437" uuid = "ce78b400-467f-4804-87d8-8f486da07d0a" @@ -1685,6 +1783,28 @@ git-tree-sha1 = "e08a62abc517eb79667d0a29dc08a3b589516bb5" uuid = "171d559e-b47b-412a-8079-5efa626c420e" version = "0.1.15" +[[deps.StableTasks]] +git-tree-sha1 = "073d5c20d44129b20fe954720b97069579fa403b" +uuid = "91464d47-22a1-43fe-8b7f-2d57ee82463f" +version = "0.1.5" + +[[deps.Static]] +deps = ["IfElse"] +git-tree-sha1 = "d2fdac9ff3906e27f7a618d47b676941baa6c80c" +uuid = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" +version = "0.8.10" + +[[deps.StaticArrayInterface]] +deps = ["ArrayInterface", "Compat", "IfElse", "LinearAlgebra", "PrecompileTools", "Requires", "SparseArrays", "Static", "SuiteSparse"] +git-tree-sha1 = "5d66818a39bb04bf328e92bc933ec5b4ee88e436" +uuid = "0d7ed370-da01-4f52-bd93-41d350b8b718" +version = "1.5.0" +weakdeps = ["OffsetArrays", "StaticArrays"] + + [deps.StaticArrayInterface.extensions] + StaticArrayInterfaceOffsetArraysExt = "OffsetArrays" + StaticArrayInterfaceStaticArraysExt = "StaticArrays" + [[deps.StaticArrays]] deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"] git-tree-sha1 = "9ae599cd7529cfce7fea36cf00a62cfc56f0f37c" @@ -1735,6 +1855,12 @@ weakdeps = ["ChainRulesCore", "InverseFunctions"] StatsFunsChainRulesCoreExt = "ChainRulesCore" StatsFunsInverseFunctionsExt = "InverseFunctions" +[[deps.StrideArraysCore]] +deps = ["ArrayInterface", "CloseOpenIntervals", "IfElse", "LayoutPointers", "LinearAlgebra", "ManualMemory", "SIMDTypes", "Static", "StaticArrayInterface", "ThreadingUtilities"] +git-tree-sha1 = "25349bf8f63aa36acbff5e3550a86e9f5b0ef682" +uuid = "7792a7ef-975c-4747-a70f-980b88e8d1da" +version = "0.5.6" + [[deps.StringManipulation]] deps = ["PrecompileTools"] git-tree-sha1 = "a04cabe79c5f01f4d723cc6704070ada0b9d46d5" @@ -1791,6 +1917,11 @@ deps = ["ArgTools", "SHA"] uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" version = "1.10.0" +[[deps.TaskLocalValues]] +git-tree-sha1 = "eb0b8d147eb907a9ad3fd952da7c6a053b29ae28" +uuid = "ed4db957-447d-4319-bfb6-7fa9ae7ecf34" +version = "0.1.1" + [[deps.TensorCore]] deps = ["LinearAlgebra"] git-tree-sha1 = "1feb45f88d133a655e001435632f019a9a1bcdb6" @@ -1807,6 +1938,17 @@ version = "0.1.7" deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +[[deps.TestItems]] +git-tree-sha1 = "8621ba2637b49748e2dc43ba3d84340be2938022" +uuid = "1c621080-faea-4a02-84b6-bbd5e436b8fe" +version = "0.1.1" + +[[deps.ThreadingUtilities]] +deps = ["ManualMemory"] +git-tree-sha1 = "eda08f7e9818eb53661b3deb74e3159460dfbc27" +uuid = "8290d209-cae3-49c0-8002-c8c24d57dab5" +version = "0.5.2" + [[deps.Tracker]] deps = ["Adapt", "ChainRulesCore", "DiffRules", "ForwardDiff", "Functors", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NNlib", "NaNMath", "Optimisers", "Printf", "Random", "Requires", "SpecialFunctions", "Statistics"] git-tree-sha1 = "5158100ed55411867674576788e710a815a0af02" @@ -1902,9 +2044,9 @@ version = "0.2.1" [[deps.UnsafeAtomicsLLVM]] deps = ["LLVM", "UnsafeAtomics"] -git-tree-sha1 = "323e3d0acf5e78a56dfae7bd8928c989b4f3083e" +git-tree-sha1 = "d9f5962fecd5ccece07db1ff006fb0b5271bdfdd" uuid = "d80eeb9a-aca5-4d75-85e5-170c8b632249" -version = "0.1.3" +version = "0.1.4" [[deps.Unzip]] git-tree-sha1 = "ca0969166a028236229f63514992fc073799bb78" @@ -1929,6 +2071,18 @@ git-tree-sha1 = "93f43ab61b16ddfb2fd3bb13b3ce241cafb0e6c9" uuid = "2381bf8a-dfd0-557d-9999-79630e7b1b91" version = "1.31.0+0" +[[deps.WeightInitializers]] +deps = ["ChainRulesCore", "LinearAlgebra", "PartialFunctions", "PrecompileTools", "Random", "SpecialFunctions", "Statistics"] +git-tree-sha1 = "f0e6760ef9d22f043710289ddf29e4a4048c4822" +uuid = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" +version = "0.1.7" + + [deps.WeightInitializers.extensions] + WeightInitializersCUDAExt = "CUDA" + + [deps.WeightInitializers.weakdeps] + CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + [[deps.WoodburyMatrices]] deps = ["LinearAlgebra", "SparseArrays"] git-tree-sha1 = "c1a7aa6219628fcd757dede0ca95e245c5cd9511" @@ -2108,22 +2262,6 @@ git-tree-sha1 = "e678132f07ddb5bfa46857f0d7620fb9be675d3b" uuid = "3161d3a3-bdf6-5164-811a-617609db77b4" version = "1.5.6+0" -[[deps.Zygote]] -deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "GPUArrays", "GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "PrecompileTools", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"] -git-tree-sha1 = "19c586905e78a26f7e4e97f81716057bd6b1bc54" -uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.6.70" - - [deps.Zygote.extensions] - ZygoteColorsExt = "Colors" - ZygoteDistancesExt = "Distances" - ZygoteTrackerExt = "Tracker" - - [deps.Zygote.weakdeps] - Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" - Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" - Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" - [[deps.ZygoteRules]] deps = ["ChainRulesCore", "MacroTools"] git-tree-sha1 = "27798139afc0a2afa7b1824c206d5e87ea587a00" diff --git a/tutorials/03-bayesian-neural-network/Project.toml b/tutorials/03-bayesian-neural-network/Project.toml index 799aa0e63..682c21d24 100755 --- a/tutorials/03-bayesian-neural-network/Project.toml +++ b/tutorials/03-bayesian-neural-network/Project.toml @@ -1,8 +1,9 @@ -[deps] -FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" -Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" -LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" -Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" +[deps] +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" +Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Lux = "b2108857-7c20-44ae-9111-449ecde12c47" +Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" diff --git a/tutorials/03-bayesian-neural-network/index.qmd b/tutorials/03-bayesian-neural-network/index.qmd index ab1fc4b34..c34280567 100755 --- a/tutorials/03-bayesian-neural-network/index.qmd +++ b/tutorials/03-bayesian-neural-network/index.qmd @@ -17,9 +17,10 @@ We will begin with importing the relevant libraries. ```{julia} using Turing using FillArrays -using Flux +using Lux using Plots using ReverseDiff +using Functors using LinearAlgebra using Random @@ -29,34 +30,31 @@ Our goal here is to use a Bayesian neural network to classify points in an artif The code below generates data points arranged in a box-like pattern and displays a graph of the dataset we will be working with. ```{julia} -# Number of points to generate. +# Number of points to generate N = 80 M = round(Int, N / 4) -Random.seed!(1234) - -# Generate artificial data. -x1s = rand(M) * 4.5; -x2s = rand(M) * 4.5; -xt1s = Array([[x1s[i] + 0.5; x2s[i] + 0.5] for i in 1:M]) -x1s = rand(M) * 4.5; -x2s = rand(M) * 4.5; -append!(xt1s, Array([[x1s[i] - 5; x2s[i] - 5] for i in 1:M])) - -x1s = rand(M) * 4.5; -x2s = rand(M) * 4.5; -xt0s = Array([[x1s[i] + 0.5; x2s[i] - 5] for i in 1:M]) -x1s = rand(M) * 4.5; -x2s = rand(M) * 4.5; -append!(xt0s, Array([[x1s[i] - 5; x2s[i] + 0.5] for i in 1:M])) - -# Store all the data for later. +rng = Random.default_rng() +Random.seed!(rng, 1234) + +# Generate artificial data +x1s = rand(rng, Float32, M) * 4.5f0; +x2s = rand(rng, Float32, M) * 4.5f0; +xt1s = Array([[x1s[i] + 0.5f0; x2s[i] + 0.5f0] for i in 1:M]) +x1s = rand(rng, Float32, M) * 4.5f0; +x2s = rand(rng, Float32, M) * 4.5f0; +append!(xt1s, Array([[x1s[i] - 5.0f0; x2s[i] - 5.0f0] for i in 1:M])) + +x1s = rand(rng, Float32, M) * 4.5f0; +x2s = rand(rng, Float32, M) * 4.5f0; +xt0s = Array([[x1s[i] + 0.5f0; x2s[i] - 5.0f0] for i in 1:M]) +x1s = rand(rng, Float32, M) * 4.5f0; +x2s = rand(rng, Float32, M) * 4.5f0; +append!(xt0s, Array([[x1s[i] - 5.0f0; x2s[i] + 0.5f0] for i in 1:M])) + +# Store all the data for later xs = [xt1s; xt0s] ts = [ones(2 * M); zeros(2 * M)] -# Convert xs to Float32 -xs = hcat(xs...) -xs = convert(Array{Float32}, xs) - # Plot data points. function plot_data() x1 = map(e -> e[1], xt1s) @@ -74,7 +72,7 @@ plot_data() ## Building a Neural Network The next step is to define a [feedforward neural network](https://en.wikipedia.org/wiki/Feedforward_neural_network) where we express our parameters as distributions, and not single points as with traditional neural networks. -For this we will use `Dense` to define liner layers and compose them via `Chain`, both are neural network primitives from Flux. +For this we will use `Dense` to define liner layers and compose them via `Chain`, both are neural network primitives from Lux. The network `nn_initial` we created has two hidden layers with `tanh` activations and one output layer with sigmoid (`σ`) activation, as shown below. ```{dot} @@ -143,36 +141,62 @@ graph G { ``` The `nn_initial` is an instance that acts as a function and can take data as inputs and output predictions. -We will define distributions on the neural network parameters and use `destructure` from Flux to extract the parameters as `parameters_initial`. -The function `destructure` also returns another function `reconstruct` that can take (new) parameters in and return us a neural network instance whose architecture is the same as `nn_initial` but with updated parameters. +We will define distributions on the neural network parameters. + ```{julia} -# Construct a neural network using Flux -nn_initial = Chain(Dense(2, 3, tanh), Dense(3, 2, tanh), Dense(2, 1, σ)) |> f32 +# Construct a neural network using Lux +nn_initial = Chain(Dense(2 => 3, tanh), Dense(3 => 2, tanh), Dense(2 => 1, σ)) -# Extract weights and a helper function to reconstruct NN from weights -parameters_initial, reconstruct = Flux.destructure(nn_initial) +# Initialize the model weights and state +ps, st = Lux.setup(rng, nn_initial) -length(parameters_initial) # number of paraemters in NN +Lux.parameterlength(nn_initial) # number of paraemters in NN ``` The probabilistic model specification below creates a `parameters` variable, which has IID normal variables. The `parameters` vector represents all parameters of our neural net (weights and biases). ```{julia} -@model function bayes_nn(xs, ts, nparameters, reconstruct; alpha=0.09) - # Create the weight and bias vector. - parameters ~ MvNormal(Zeros(nparameters), I / alpha) +# Create a regularization term and a Gaussian prior variance term. +alpha = 0.09 +sig = sqrt(1.0 / alpha) +``` + +Construct named tuple from a sampled parameter vector. We could also use ComponentArrays here and simply broadcast to avoid doing this. But let's do it this way to avoid dependencies. + +```{julia} +function vector_to_parameters(ps_new::AbstractVector, ps::NamedTuple) + @assert length(ps_new) == Lux.parameterlength(ps) + i = 1 + function get_ps(x) + z = reshape(view(ps_new, i:(i + length(x) - 1)), size(x)) + i += length(x) + return z + end + return fmap(get_ps, ps) +end +``` + +To interface with external libraries it is often desirable to use the [`StatefulLuxLayer`](https://lux.csail.mit.edu/stable/api/Lux/contrib#StatefulLuxLayer) to automatically handle the neural network states. + +```{julia} +const model = StatefulLuxLayer(nn_initial, st) + +# Specify the probabilistic model. +@model function bayes_nn(xs, ts) + # Sample the parameters + nparameters = Lux.parameterlength(nn_initial) + parameters ~ MvNormal(zeros(nparameters), Diagonal(abs2.(sig .* ones(nparameters)))) - # Construct NN from parameters - nn = reconstruct(parameters) # Forward NN to make predictions - preds = nn(xs) + preds = Lux.apply(model, xs, vector_to_parameters(parameters, ps)) # Observe each prediction. - for i in 1:length(ts) + for i in eachindex(ts) ts[i] ~ Bernoulli(preds[i]) end -end; +end ``` Inference can now be performed by calling `sample`. We use the `NUTS` Hamiltonian Monte Carlo sampler here. @@ -185,15 +209,15 @@ setprogress!(false) ```{julia} # Perform inference. N = 5000 -ch = sample(bayes_nn(xs, ts, length(parameters_initial), reconstruct), NUTS(;adtype=AutoReverseDiff()), N); +ch = sample(bayes_nn(reduce(hcat, xs), ts), NUTS(; adtype=AutoReverseDiff()), N); ``` -Now we extract the parameter samples from the sampled chain as `theta` (this is of size `5000 x 20` where `5000` is the number of iterations and `20` is the number of parameters). +Now we extract the parameter samples from the sampled chain as `θ` (this is of size `5000 x 20` where `5000` is the number of iterations and `20` is the number of parameters). We'll use these primarily to determine how good our model's classifier is. ```{julia} # Extract all weight and bias parameters. -theta = convert(Array{Float32}, MCMCChains.group(ch, :parameters).value) +θ = MCMCChains.group(ch, :parameters).value; ``` ## Prediction Visualization @@ -201,11 +225,11 @@ theta = convert(Array{Float32}, MCMCChains.group(ch, :parameters).value) We can use [MAP estimation](https://en.wikipedia.org/wiki/Maximum_a_posteriori_estimation) to classify our population by using the set of weights that provided the highest log posterior. ```{julia} -# A helper to create NN from weights `theta` and run it through data `x` -nn_forward(x, theta) = reconstruct(theta)(x) +# A helper to run the nn through data `x` using parameters `θ` +nn_forward(x, θ) = model(x, vector_to_parameters(θ, ps)) # Plot the data we have. -plot_data() +fig = plot_data() # Find the index that provided the highest log posterior in the chain. _, i = findmax(ch[:lp]) @@ -216,8 +240,9 @@ i = i.I[1] # Plot the posterior distribution with a contour plot x1_range = collect(range(-6; stop=6, length=25)) x2_range = collect(range(-6; stop=6, length=25)) -Z = [nn_forward(Float32[x1, x2], theta[i, :])[1] for x1 in x1_range, x2 in x2_range] -contour!(x1_range, x2_range, Z) +Z = [nn_forward([x1, x2], θ[i, :])[1] for x1 in x1_range, x2 in x2_range] +contour!(x1_range, x2_range, Z; linewidth=3, colormap=:seaborn_bright) +fig ``` The contour plot above shows that the MAP method is not too bad at classifying our data. @@ -233,9 +258,8 @@ The `nn_predict` function takes the average predicted value from a network param ```{julia} # Return the average predicted value across # multiple weights. -function nn_predict(x, theta, num) - x = convert(Vector{Float32}, x) # Ensure x is Float32 - return mean([nn_forward(x, theta[i, :])[1] for i in 1:10:num]) +function nn_predict(x, θ, num) + return mean([first(nn_forward(x, view(θ, i, :))) for i in 1:10:num]) end ``` @@ -244,27 +268,25 @@ Next, we use the `nn_predict` function to predict the value at a sample of point ```{julia} # Plot the average prediction. -plot_data() +fig = plot_data() n_end = 1500 -x1_range = collect(range(-6, stop=6, length=25)) -x2_range = collect(range(-6, stop=6, length=25)) - -# Ensure x1, x2 are Float32 within the comprehension -Z = [nn_predict(Float32[x1, x2], theta, n_end)[1] for x1 in x1_range, x2 in x2_range] -contour!(x1_range, x2_range, Z) +x1_range = collect(range(-6; stop=6, length=25)) +x2_range = collect(range(-6; stop=6, length=25)) +Z = [nn_predict([x1, x2], θ, n_end)[1] for x1 in x1_range, x2 in x2_range] +contour!(x1_range, x2_range, Z; linewidth=3, colormap=:seaborn_bright) +fig ``` Suppose we are interested in how the predictive power of our Bayesian neural network evolved between samples. In that case, the following graph displays an animation of the contour plot generated from the network weights in samples 1 to 1,000. - ```{julia} # Number of iterations to plot. n_end = 500 anim = @gif for i in 1:n_end plot_data() - Z = [nn_forward(Float32[x1, x2], theta[i, :])[1] for x1 in x1_range, x2 in x2_range] + Z = [nn_forward([x1, x2], θ[i, :])[1] for x1 in x1_range, x2 in x2_range] contour!(x1_range, x2_range, Z; title="Iteration $i", clim=(0, 1)) end every 5 ```