From 78addb86ac975349944748bef73de87ec1272427 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 16 Nov 2022 11:11:15 -0500 Subject: [PATCH 1/2] init --- Manifest.toml | 702 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 702 insertions(+) create mode 100644 Manifest.toml diff --git a/Manifest.toml b/Manifest.toml new file mode 100644 index 0000000..e53e048 --- /dev/null +++ b/Manifest.toml @@ -0,0 +1,702 @@ +# This file is machine-generated - editing it directly is not advised + +julia_version = "1.9.0-DEV" +manifest_format = "2.0" +project_hash = "67c4978faee91f2c50a70573bb50512caec0c109" + +[[deps.AbstractFFTs]] +deps = ["ChainRulesCore", "LinearAlgebra"] +git-tree-sha1 = "69f7020bd72f069c219b5e8c236c1fa90d2cb409" +uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" +version = "1.2.1" + +[[deps.Accessors]] +deps = ["Compat", "CompositionsBase", "ConstructionBase", "Dates", "InverseFunctions", "LinearAlgebra", "MacroTools", "Requires", "Test"] +git-tree-sha1 = "eb7a1342ff77f4f9b6552605f27fd432745a53a3" +uuid = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" +version = "0.1.22" + +[[deps.Adapt]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "195c5505521008abea5aee4f96930717958eac6f" +uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +version = "3.4.0" + +[[deps.ArgCheck]] +git-tree-sha1 = "a3a402a35a2f7e0b87828ccabbd5ebfbebe356b4" +uuid = "dce04be8-c92d-5529-be00-80e4d2c0e197" +version = "2.3.0" + +[[deps.ArgTools]] +uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" +version = "1.1.1" + +[[deps.ArrayInterface]] +deps = ["ArrayInterfaceCore", "Compat", "IfElse", "LinearAlgebra", "Static"] +git-tree-sha1 = "d6173480145eb632d6571c148d94b9d3d773820e" +uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" +version = "6.0.23" + +[[deps.ArrayInterfaceCore]] +deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"] +git-tree-sha1 = "c46fb7dd1d8ca1d213ba25848a5ec4e47a1a1b08" +uuid = "30b0a656-2188-435a-8636-2ec0e6a096e2" +version = "0.1.26" + +[[deps.Artifacts]] +uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" + +[[deps.BFloat16s]] +deps = ["LinearAlgebra", "Printf", "Random", "Test"] +git-tree-sha1 = "a598ecb0d717092b5539dbbe890c98bac842b072" +uuid = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" +version = "0.2.0" + +[[deps.BangBang]] +deps = ["Compat", "ConstructionBase", "Future", "InitialValues", "LinearAlgebra", "Requires", "Setfield", "Tables", "ZygoteRules"] +git-tree-sha1 = "7fe6d92c4f281cf4ca6f2fba0ce7b299742da7ca" +uuid = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" +version = "0.3.37" + +[[deps.Base64]] +uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" + +[[deps.Baselet]] +git-tree-sha1 = "aebf55e6d7795e02ca500a689d326ac979aaf89e" +uuid = "9718e550-a3fa-408a-8086-8db961cd8217" +version = "0.1.1" + +[[deps.CEnum]] +git-tree-sha1 = "eb4cb44a499229b3b8426dcfb5dd85333951ff90" +uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" +version = "0.4.2" + +[[deps.CUDA]] +deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CompilerSupportLibraries_jll", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "TimerOutputs"] +git-tree-sha1 = "49549e2c28ffb9cc77b3689dc10e46e6271e9452" +uuid = "052768ef-5323-5732-b1bb-66c8b64840ba" +version = "3.12.0" + +[[deps.ChainRules]] +deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "Statistics", "StructArrays"] +git-tree-sha1 = "0c8c8887763f42583e1206ee35413a43c91e2623" +uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" +version = "1.45.0" + +[[deps.ChainRulesCore]] +deps = ["Compat", "LinearAlgebra", "SparseArrays"] +git-tree-sha1 = "e7ff6cadf743c098e08fca25c91103ee4303c9bb" +uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +version = "1.15.6" + +[[deps.ChangesOfVariables]] +deps = ["ChainRulesCore", "LinearAlgebra", "Test"] +git-tree-sha1 = "38f7a08f19d8810338d4f5085211c7dfa5d5bdd8" +uuid = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" +version = "0.1.4" + +[[deps.CommonSubexpressions]] +deps = ["MacroTools", "Test"] +git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7" +uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" +version = "0.3.0" + +[[deps.Compat]] +deps = ["Dates", "LinearAlgebra", "UUIDs"] +git-tree-sha1 = "3ca828fe1b75fa84b021a7860bd039eaea84d2f2" +uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" +version = "4.3.0" + +[[deps.CompilerSupportLibraries_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" +version = "1.0.1+0" + +[[deps.CompositionsBase]] +git-tree-sha1 = "455419f7e328a1a2493cabc6428d79e951349769" +uuid = "a33af91c-f02d-484b-be07-31d278c5ca2b" +version = "0.1.1" + +[[deps.ConstructionBase]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "fb21ddd70a051d882a1686a5a550990bbe371a95" +uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9" +version = "1.4.1" + +[[deps.ContextVariablesX]] +deps = ["Compat", "Logging", "UUIDs"] +git-tree-sha1 = "25cc3803f1030ab855e383129dcd3dc294e322cc" +uuid = "6add18c4-b38d-439d-96f6-d6bc489c04c5" +version = "0.1.3" + +[[deps.DataAPI]] +git-tree-sha1 = "e08915633fcb3ea83bf9d6126292e5bc5c739922" +uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" +version = "1.13.0" + +[[deps.DataStructures]] +deps = ["Compat", "InteractiveUtils", "OrderedCollections"] +git-tree-sha1 = "d1fff3a548102f48987a52a2e0d114fa97d730f0" +uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" +version = "0.18.13" + +[[deps.DataValueInterfaces]] +git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" +uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464" +version = "1.0.0" + +[[deps.Dates]] +deps = ["Printf"] +uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" + +[[deps.DefineSingletons]] +git-tree-sha1 = "0fba8b706d0178b4dc7fd44a96a92382c9065c2c" +uuid = "244e2a9f-e319-4986-a169-4d1fe445cd52" +version = "0.1.2" + +[[deps.DelimitedFiles]] +deps = ["Mmap"] +git-tree-sha1 = "19b1417ff479c07e523fcbf2fd735a3fde3d1ab3" +uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" +version = "1.9.0" + +[[deps.DiffResults]] +deps = ["StaticArraysCore"] +git-tree-sha1 = "782dd5f4561f5d267313f23853baaaa4c52ea621" +uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" +version = "1.1.0" + +[[deps.DiffRules]] +deps = ["IrrationalConstants", "LogExpFunctions", "NaNMath", "Random", "SpecialFunctions"] +git-tree-sha1 = "c5b6685d53f933c11404a3ae9822afe30d522494" +uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" +version = "1.12.2" + +[[deps.Distributed]] +deps = ["Random", "Serialization", "Sockets"] +uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" + +[[deps.DocStringExtensions]] +deps = ["LibGit2"] +git-tree-sha1 = "c36550cb29cbe373e95b3f40486b9a4148f89ffd" +uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +version = "0.9.2" + +[[deps.Downloads]] +deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] +uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" +version = "1.6.0" + +[[deps.ExprTools]] +git-tree-sha1 = "56559bbef6ca5ea0c0818fa5c90320398a6fbf8d" +uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04" +version = "0.1.8" + +[[deps.FLoops]] +deps = ["BangBang", "Compat", "FLoopsBase", "InitialValues", "JuliaVariables", "MLStyle", "Serialization", "Setfield", "Transducers"] +git-tree-sha1 = "ffb97765602e3cbe59a0589d237bf07f245a8576" +uuid = "cc61a311-1640-44b5-9fba-1b764f453329" +version = "0.2.1" + +[[deps.FLoopsBase]] +deps = ["ContextVariablesX"] +git-tree-sha1 = "656f7a6859be8673bf1f35da5670246b923964f7" +uuid = "b9860ae5-e623-471e-878b-f6a53c775ea6" +version = "0.1.1" + +[[deps.FileWatching]] +uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" + +[[deps.FillArrays]] +deps = ["LinearAlgebra", "Random", "SparseArrays", "Statistics"] +git-tree-sha1 = "802bfc139833d2ba893dd9e62ba1767c88d708ae" +uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" +version = "0.13.5" + +[[deps.Flux]] +deps = ["Adapt", "ArrayInterface", "CUDA", "ChainRulesCore", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "NNlibCUDA", "OneHotArrays", "Optimisers", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "StatsBase", "Test", "Zygote"] +git-tree-sha1 = "66b62bf72c4b5d4904441ed0677eab53266033c7" +uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" +version = "0.13.7" + +[[deps.FoldsThreads]] +deps = ["Accessors", "FunctionWrappers", "InitialValues", "SplittablesBase", "Transducers"] +git-tree-sha1 = "eb8e1989b9028f7e0985b4268dabe94682249025" +uuid = "9c68100b-dfe1-47cf-94c8-95104e173443" +version = "0.1.1" + +[[deps.ForwardDiff]] +deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions", "StaticArrays"] +git-tree-sha1 = "10fa12fe96e4d76acfa738f4df2126589a67374f" +uuid = "f6369f11-7733-5829-9624-2563aa707210" +version = "0.10.33" + +[[deps.FunctionWrappers]] +git-tree-sha1 = "d62485945ce5ae9c0c48f124a84998d755bae00e" +uuid = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e" +version = "1.1.3" + +[[deps.Functors]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "a2657dd0f3e8a61dbe70fc7c122038bd33790af5" +uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +version = "0.3.0" + +[[deps.Future]] +deps = ["Random"] +uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" + +[[deps.GPUArrays]] +deps = ["Adapt", "GPUArraysCore", "LLVM", "LinearAlgebra", "Printf", "Random", "Reexport", "Serialization", "Statistics"] +git-tree-sha1 = "45d7deaf05cbb44116ba785d147c518ab46352d7" +uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" +version = "8.5.0" + +[[deps.GPUArraysCore]] +deps = ["Adapt"] +git-tree-sha1 = "6872f5ec8fd1a38880f027a26739d42dcda6691f" +uuid = "46192b85-c4d5-4398-a991-12ede77f4527" +version = "0.1.2" + +[[deps.GPUCompiler]] +deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "TimerOutputs", "UUIDs"] +git-tree-sha1 = "76f70a337a153c1632104af19d29023dbb6f30dd" +uuid = "61eb1bfa-7361-4325-ad38-22787b887f55" +version = "0.16.6" + +[[deps.IRTools]] +deps = ["InteractiveUtils", "MacroTools", "Test"] +git-tree-sha1 = "2e99184fca5eb6f075944b04c22edec29beb4778" +uuid = "7869d1d1-7146-5819-86e3-90919afe41df" +version = "0.4.7" + +[[deps.IfElse]] +git-tree-sha1 = "debdd00ffef04665ccbb3e150747a77560e8fad1" +uuid = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" +version = "0.1.1" + +[[deps.InitialValues]] +git-tree-sha1 = "4da0f88e9a39111c2fa3add390ab15f3a44f3ca3" +uuid = "22cec73e-a1b8-11e9-2c92-598750a2cf9c" +version = "0.3.1" + +[[deps.InteractiveUtils]] +deps = ["Markdown"] +uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" + +[[deps.InverseFunctions]] +deps = ["Test"] +git-tree-sha1 = "49510dfcb407e572524ba94aeae2fced1f3feb0f" +uuid = "3587e190-3f89-42d0-90ee-14403ec27112" +version = "0.1.8" + +[[deps.IrrationalConstants]] +git-tree-sha1 = "7fd44fd4ff43fc60815f8e764c0f352b83c49151" +uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" +version = "0.1.1" + +[[deps.IteratorInterfaceExtensions]] +git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" +uuid = "82899510-4779-5014-852e-03e436cf321d" +version = "1.0.0" + +[[deps.JLLWrappers]] +deps = ["Preferences"] +git-tree-sha1 = "abc9885a7ca2052a736a600f7fa66209f96506e1" +uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" +version = "1.4.1" + +[[deps.JuliaVariables]] +deps = ["MLStyle", "NameResolution"] +git-tree-sha1 = "49fb3cb53362ddadb4415e9b73926d6b40709e70" +uuid = "b14d175d-62b4-44ba-8fb7-3064adc8c3ec" +version = "0.2.4" + +[[deps.LLVM]] +deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Printf", "Unicode"] +git-tree-sha1 = "e7e9184b0bf0158ac4e4aa9daf00041b5909bf1a" +uuid = "929cbde3-209d-540e-8aea-75f648917ca0" +version = "4.14.0" + +[[deps.LLVMExtra_jll]] +deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg", "TOML"] +git-tree-sha1 = "771bfe376249626d3ca12bcd58ba243d3f961576" +uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab" +version = "0.0.16+0" + +[[deps.LazyArtifacts]] +deps = ["Artifacts", "Pkg"] +uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3" + +[[deps.LibCURL]] +deps = ["LibCURL_jll", "MozillaCACerts_jll"] +uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" +version = "0.6.3" + +[[deps.LibCURL_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] +uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" +version = "7.84.0+0" + +[[deps.LibGit2]] +deps = ["Base64", "NetworkOptions", "Printf", "SHA"] +uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" + +[[deps.LibSSH2_jll]] +deps = ["Artifacts", "Libdl", "MbedTLS_jll"] +uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" +version = "1.10.2+0" + +[[deps.Libdl]] +uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" + +[[deps.LinearAlgebra]] +deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"] +uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + +[[deps.LogExpFunctions]] +deps = ["ChainRulesCore", "ChangesOfVariables", "DocStringExtensions", "InverseFunctions", "IrrationalConstants", "LinearAlgebra"] +git-tree-sha1 = "94d9c52ca447e23eac0c0f074effbcd38830deb5" +uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +version = "0.3.18" + +[[deps.Logging]] +uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" + +[[deps.MLStyle]] +git-tree-sha1 = "060ef7956fef2dc06b0e63b294f7dbfbcbdc7ea2" +uuid = "d8e11817-5142-5d16-987a-aa16d5891078" +version = "0.4.16" + +[[deps.MLUtils]] +deps = ["ChainRulesCore", "DelimitedFiles", "FLoops", "FoldsThreads", "Random", "ShowCases", "Statistics", "StatsBase", "Transducers"] +git-tree-sha1 = "824e9dfc7509cab1ec73ba77b55a916bb2905e26" +uuid = "f1d291b0-491e-4a28-83b9-f70985020b54" +version = "0.2.11" + +[[deps.MacroTools]] +deps = ["Markdown", "Random"] +git-tree-sha1 = "42324d08725e200c23d4dfb549e0d5d89dede2d2" +uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +version = "0.5.10" + +[[deps.Markdown]] +deps = ["Base64"] +uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" + +[[deps.MbedTLS_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" +version = "2.28.0+0" + +[[deps.MicroCollections]] +deps = ["BangBang", "InitialValues", "Setfield"] +git-tree-sha1 = "4d5917a26ca33c66c8e5ca3247bd163624d35493" +uuid = "128add7d-3638-4c79-886c-908ea0c25c34" +version = "0.1.3" + +[[deps.Missings]] +deps = ["DataAPI"] +git-tree-sha1 = "bf210ce90b6c9eed32d25dbcae1ebc565df2687f" +uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" +version = "1.0.2" + +[[deps.Mmap]] +uuid = "a63ad114-7e13-5084-954f-fe012c677804" + +[[deps.MozillaCACerts_jll]] +uuid = "14a3606d-f60d-562e-9121-12d972cd8159" +version = "2022.10.11" + +[[deps.NNlib]] +deps = ["Adapt", "ChainRulesCore", "LinearAlgebra", "Pkg", "Requires", "Statistics"] +git-tree-sha1 = "00bcfcea7b2063807fdcab2e0ce86ef00b8b8000" +uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +version = "0.8.10" + +[[deps.NNlibCUDA]] +deps = ["Adapt", "CUDA", "LinearAlgebra", "NNlib", "Random", "Statistics"] +git-tree-sha1 = "4429261364c5ea5b7308aecaa10e803ace101631" +uuid = "a00861dc-f156-4864-bf3c-e6376f28a68d" +version = "0.2.4" + +[[deps.NaNMath]] +deps = ["OpenLibm_jll"] +git-tree-sha1 = "a7c3d1da1189a1c2fe843a3bfa04d18d20eb3211" +uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" +version = "1.0.1" + +[[deps.NameResolution]] +deps = ["PrettyPrint"] +git-tree-sha1 = "1a0fa0e9613f46c9b8c11eee38ebb4f590013c5e" +uuid = "71a1bf82-56d0-4bbc-8a3c-48b961074391" +version = "0.1.5" + +[[deps.NetworkOptions]] +uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" +version = "1.2.0" + +[[deps.OneHotArrays]] +deps = ["Adapt", "ChainRulesCore", "GPUArrays", "LinearAlgebra", "MLUtils", "NNlib"] +git-tree-sha1 = "2f6efe2f76d57a0ee67cb6eff49b4d02fccbd175" +uuid = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" +version = "0.1.0" + +[[deps.OpenBLAS_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] +uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" +version = "0.3.21+0" + +[[deps.OpenLibm_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "05823500-19ac-5b8b-9628-191a04bc5112" +version = "0.8.1+0" + +[[deps.OpenSpecFun_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1" +uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" +version = "0.5.5+0" + +[[deps.Optimisers]] +deps = ["ChainRulesCore", "Functors", "LinearAlgebra", "Random", "Statistics"] +git-tree-sha1 = "8a9102cb805df46fc3d6effdc2917f09b0215c0b" +uuid = "3bd65402-5787-11e9-1adc-39752487f4e2" +version = "0.2.10" + +[[deps.OrderedCollections]] +git-tree-sha1 = "85f8e6578bf1f9ee0d11e7bb1b1456435479d47c" +uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" +version = "1.4.1" + +[[deps.Pkg]] +deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] +uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +version = "1.8.0" + +[[deps.Preferences]] +deps = ["TOML"] +git-tree-sha1 = "47e5f437cc0e7ef2ce8406ce1e7e24d44915f88d" +uuid = "21216c6a-2e73-6563-6e65-726566657250" +version = "1.3.0" + +[[deps.PrettyPrint]] +git-tree-sha1 = "632eb4abab3449ab30c5e1afaa874f0b98b586e4" +uuid = "8162dcfd-2161-5ef2-ae6c-7681170c5f98" +version = "0.2.0" + +[[deps.Printf]] +deps = ["Unicode"] +uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" + +[[deps.ProgressLogging]] +deps = ["Logging", "SHA", "UUIDs"] +git-tree-sha1 = "80d919dee55b9c50e8d9e2da5eeafff3fe58b539" +uuid = "33c8b6b6-d38a-422a-b730-caa89a2f386c" +version = "0.1.4" + +[[deps.REPL]] +deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] +uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" + +[[deps.Random]] +deps = ["SHA", "Serialization"] +uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" + +[[deps.Random123]] +deps = ["Random", "RandomNumbers"] +git-tree-sha1 = "7a1a306b72cfa60634f03a911405f4e64d1b718b" +uuid = "74087812-796a-5b5d-8853-05524746bad3" +version = "1.6.0" + +[[deps.RandomNumbers]] +deps = ["Random", "Requires"] +git-tree-sha1 = "043da614cc7e95c703498a491e2c21f58a2b8111" +uuid = "e6cf234a-135c-5ec9-84dd-332b85af5143" +version = "1.5.3" + +[[deps.RealDot]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "9f0a1b71baaf7650f4fa8a1d168c7fb6ee41f0c9" +uuid = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9" +version = "0.1.0" + +[[deps.Reexport]] +git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b" +uuid = "189a3867-3050-52da-a836-e630ba90ab69" +version = "1.2.2" + +[[deps.Requires]] +deps = ["UUIDs"] +git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7" +uuid = "ae029012-a4dd-5104-9daa-d747884805df" +version = "1.3.0" + +[[deps.SHA]] +uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" +version = "0.7.0" + +[[deps.Serialization]] +uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" + +[[deps.Setfield]] +deps = ["ConstructionBase", "Future", "MacroTools", "StaticArraysCore"] +git-tree-sha1 = "e2cc6d8c88613c05e1defb55170bf5ff211fbeac" +uuid = "efcf1570-3423-57d1-acb7-fd33fddbac46" +version = "1.1.1" + +[[deps.ShowCases]] +git-tree-sha1 = "7f534ad62ab2bd48591bdeac81994ea8c445e4a5" +uuid = "605ecd9f-84a6-4c9e-81e2-4798472b76a3" +version = "0.1.0" + +[[deps.Sockets]] +uuid = "6462fe0b-24de-5631-8697-dd941f90decc" + +[[deps.SortingAlgorithms]] +deps = ["DataStructures"] +git-tree-sha1 = "a4ada03f999bd01b3a25dcaa30b2d929fe537e00" +uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" +version = "1.1.0" + +[[deps.SparseArrays]] +deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] +uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + +[[deps.SpecialFunctions]] +deps = ["ChainRulesCore", "IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] +git-tree-sha1 = "d75bda01f8c31ebb72df80a46c88b25d1c79c56d" +uuid = "276daf66-3868-5448-9aa4-cd146d93841b" +version = "2.1.7" + +[[deps.SplittablesBase]] +deps = ["Setfield", "Test"] +git-tree-sha1 = "e08a62abc517eb79667d0a29dc08a3b589516bb5" +uuid = "171d559e-b47b-412a-8079-5efa626c420e" +version = "0.1.15" + +[[deps.Static]] +deps = ["IfElse"] +git-tree-sha1 = "03170c1e8a94732c1d835ce4c5b904b4b52cba1c" +uuid = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" +version = "0.7.8" + +[[deps.StaticArrays]] +deps = ["LinearAlgebra", "Random", "StaticArraysCore", "Statistics"] +git-tree-sha1 = "f86b3a049e5d05227b10e15dbb315c5b90f14988" +uuid = "90137ffa-7385-5640-81b9-e52037218182" +version = "1.5.9" + +[[deps.StaticArraysCore]] +git-tree-sha1 = "6b7ba252635a5eff6a0b0664a41ee140a1c9e72a" +uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" +version = "1.4.0" + +[[deps.Statistics]] +deps = ["LinearAlgebra", "SparseArrays"] +uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +version = "1.9.0" + +[[deps.StatsAPI]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "f9af7f195fb13589dd2e2d57fdb401717d2eb1f6" +uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" +version = "1.5.0" + +[[deps.StatsBase]] +deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] +git-tree-sha1 = "d1bf48bfcc554a3761a133fe3a9bb01488e06916" +uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +version = "0.33.21" + +[[deps.StructArrays]] +deps = ["Adapt", "DataAPI", "StaticArraysCore", "Tables"] +git-tree-sha1 = "13237798b407150a6d2e2bce5d793d7d9576e99e" +uuid = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" +version = "0.6.13" + +[[deps.SuiteSparse]] +deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] +uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" + +[[deps.SuiteSparse_jll]] +deps = ["Artifacts", "Libdl", "Pkg", "libblastrampoline_jll"] +uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" +version = "5.10.1+0" + +[[deps.TOML]] +deps = ["Dates"] +uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" +version = "1.0.3" + +[[deps.TableTraits]] +deps = ["IteratorInterfaceExtensions"] +git-tree-sha1 = "c06b2f539df1c6efa794486abfb6ed2022561a39" +uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c" +version = "1.0.1" + +[[deps.Tables]] +deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "LinearAlgebra", "OrderedCollections", "TableTraits", "Test"] +git-tree-sha1 = "c79322d36826aa2f4fd8ecfa96ddb47b174ac78d" +uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" +version = "1.10.0" + +[[deps.Tar]] +deps = ["ArgTools", "SHA"] +uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" +version = "1.10.0" + +[[deps.Test]] +deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] +uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[[deps.TimerOutputs]] +deps = ["ExprTools", "Printf"] +git-tree-sha1 = "f2fd3f288dfc6f507b0c3a2eb3bac009251e548b" +uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" +version = "0.5.22" + +[[deps.Transducers]] +deps = ["Adapt", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "Setfield", "SplittablesBase", "Tables"] +git-tree-sha1 = "77fea79baa5b22aeda896a8d9c6445a74500a2c2" +uuid = "28d57a85-8fef-5791-bfe6-a80928e7c999" +version = "0.4.74" + +[[deps.UUIDs]] +deps = ["Random", "SHA"] +uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" + +[[deps.Unicode]] +uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" + +[[deps.Zlib_jll]] +deps = ["Libdl"] +uuid = "83775a58-1f1d-513f-b197-d71354ab007a" +version = "1.2.13+0" + +[[deps.Zygote]] +deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "GPUArrays", "GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"] +git-tree-sha1 = "66cc604b9a27a660e25a54e408b4371123a186a6" +uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" +version = "0.6.49" + +[[deps.ZygoteRules]] +deps = ["MacroTools"] +git-tree-sha1 = "8c1a8e4dfacb1fd631745552c8db35d0deb09ea0" +uuid = "700de1a5-db45-46bc-99cf-38207098b444" +version = "0.2.2" + +[[deps.libblastrampoline_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" +version = "5.2.0+0" + +[[deps.nghttp2_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" +version = "1.48.0+0" + +[[deps.p7zip_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" +version = "17.4.0+0" From d2eef4707ee9ab59457ca70bb3803fade8a3c91c Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 16 Nov 2022 13:30:13 -0500 Subject: [PATCH 2/2] new train function --- Manifest.toml | 702 ------------------------------------------ Project.toml | 2 + README.md | 2 +- src/Fluxperimental.jl | 3 + src/train.jl | 78 +++++ test/train.jl | 27 ++ 6 files changed, 111 insertions(+), 703 deletions(-) delete mode 100644 Manifest.toml create mode 100644 src/train.jl create mode 100644 test/train.jl diff --git a/Manifest.toml b/Manifest.toml deleted file mode 100644 index e53e048..0000000 --- a/Manifest.toml +++ /dev/null @@ -1,702 +0,0 @@ -# This file is machine-generated - editing it directly is not advised - -julia_version = "1.9.0-DEV" -manifest_format = "2.0" -project_hash = "67c4978faee91f2c50a70573bb50512caec0c109" - -[[deps.AbstractFFTs]] -deps = ["ChainRulesCore", "LinearAlgebra"] -git-tree-sha1 = "69f7020bd72f069c219b5e8c236c1fa90d2cb409" -uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" -version = "1.2.1" - -[[deps.Accessors]] -deps = ["Compat", "CompositionsBase", "ConstructionBase", "Dates", "InverseFunctions", "LinearAlgebra", "MacroTools", "Requires", "Test"] -git-tree-sha1 = "eb7a1342ff77f4f9b6552605f27fd432745a53a3" -uuid = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" -version = "0.1.22" - -[[deps.Adapt]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "195c5505521008abea5aee4f96930717958eac6f" -uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -version = "3.4.0" - -[[deps.ArgCheck]] -git-tree-sha1 = "a3a402a35a2f7e0b87828ccabbd5ebfbebe356b4" -uuid = "dce04be8-c92d-5529-be00-80e4d2c0e197" -version = "2.3.0" - -[[deps.ArgTools]] -uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" -version = "1.1.1" - -[[deps.ArrayInterface]] -deps = ["ArrayInterfaceCore", "Compat", "IfElse", "LinearAlgebra", "Static"] -git-tree-sha1 = "d6173480145eb632d6571c148d94b9d3d773820e" -uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" -version = "6.0.23" - -[[deps.ArrayInterfaceCore]] -deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"] -git-tree-sha1 = "c46fb7dd1d8ca1d213ba25848a5ec4e47a1a1b08" -uuid = "30b0a656-2188-435a-8636-2ec0e6a096e2" -version = "0.1.26" - -[[deps.Artifacts]] -uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" - -[[deps.BFloat16s]] -deps = ["LinearAlgebra", "Printf", "Random", "Test"] -git-tree-sha1 = "a598ecb0d717092b5539dbbe890c98bac842b072" -uuid = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" -version = "0.2.0" - -[[deps.BangBang]] -deps = ["Compat", "ConstructionBase", "Future", "InitialValues", "LinearAlgebra", "Requires", "Setfield", "Tables", "ZygoteRules"] -git-tree-sha1 = "7fe6d92c4f281cf4ca6f2fba0ce7b299742da7ca" -uuid = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" -version = "0.3.37" - -[[deps.Base64]] -uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" - -[[deps.Baselet]] -git-tree-sha1 = "aebf55e6d7795e02ca500a689d326ac979aaf89e" -uuid = "9718e550-a3fa-408a-8086-8db961cd8217" -version = "0.1.1" - -[[deps.CEnum]] -git-tree-sha1 = "eb4cb44a499229b3b8426dcfb5dd85333951ff90" -uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" -version = "0.4.2" - -[[deps.CUDA]] -deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CompilerSupportLibraries_jll", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "TimerOutputs"] -git-tree-sha1 = "49549e2c28ffb9cc77b3689dc10e46e6271e9452" -uuid = "052768ef-5323-5732-b1bb-66c8b64840ba" -version = "3.12.0" - -[[deps.ChainRules]] -deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "Statistics", "StructArrays"] -git-tree-sha1 = "0c8c8887763f42583e1206ee35413a43c91e2623" -uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.45.0" - -[[deps.ChainRulesCore]] -deps = ["Compat", "LinearAlgebra", "SparseArrays"] -git-tree-sha1 = "e7ff6cadf743c098e08fca25c91103ee4303c9bb" -uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "1.15.6" - -[[deps.ChangesOfVariables]] -deps = ["ChainRulesCore", "LinearAlgebra", "Test"] -git-tree-sha1 = "38f7a08f19d8810338d4f5085211c7dfa5d5bdd8" -uuid = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" -version = "0.1.4" - -[[deps.CommonSubexpressions]] -deps = ["MacroTools", "Test"] -git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7" -uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" -version = "0.3.0" - -[[deps.Compat]] -deps = ["Dates", "LinearAlgebra", "UUIDs"] -git-tree-sha1 = "3ca828fe1b75fa84b021a7860bd039eaea84d2f2" -uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "4.3.0" - -[[deps.CompilerSupportLibraries_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" -version = "1.0.1+0" - -[[deps.CompositionsBase]] -git-tree-sha1 = "455419f7e328a1a2493cabc6428d79e951349769" -uuid = "a33af91c-f02d-484b-be07-31d278c5ca2b" -version = "0.1.1" - -[[deps.ConstructionBase]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "fb21ddd70a051d882a1686a5a550990bbe371a95" -uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9" -version = "1.4.1" - -[[deps.ContextVariablesX]] -deps = ["Compat", "Logging", "UUIDs"] -git-tree-sha1 = "25cc3803f1030ab855e383129dcd3dc294e322cc" -uuid = "6add18c4-b38d-439d-96f6-d6bc489c04c5" -version = "0.1.3" - -[[deps.DataAPI]] -git-tree-sha1 = "e08915633fcb3ea83bf9d6126292e5bc5c739922" -uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" -version = "1.13.0" - -[[deps.DataStructures]] -deps = ["Compat", "InteractiveUtils", "OrderedCollections"] -git-tree-sha1 = "d1fff3a548102f48987a52a2e0d114fa97d730f0" -uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -version = "0.18.13" - -[[deps.DataValueInterfaces]] -git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" -uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464" -version = "1.0.0" - -[[deps.Dates]] -deps = ["Printf"] -uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" - -[[deps.DefineSingletons]] -git-tree-sha1 = "0fba8b706d0178b4dc7fd44a96a92382c9065c2c" -uuid = "244e2a9f-e319-4986-a169-4d1fe445cd52" -version = "0.1.2" - -[[deps.DelimitedFiles]] -deps = ["Mmap"] -git-tree-sha1 = "19b1417ff479c07e523fcbf2fd735a3fde3d1ab3" -uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" -version = "1.9.0" - -[[deps.DiffResults]] -deps = ["StaticArraysCore"] -git-tree-sha1 = "782dd5f4561f5d267313f23853baaaa4c52ea621" -uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" -version = "1.1.0" - -[[deps.DiffRules]] -deps = ["IrrationalConstants", "LogExpFunctions", "NaNMath", "Random", "SpecialFunctions"] -git-tree-sha1 = "c5b6685d53f933c11404a3ae9822afe30d522494" -uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" -version = "1.12.2" - -[[deps.Distributed]] -deps = ["Random", "Serialization", "Sockets"] -uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" - -[[deps.DocStringExtensions]] -deps = ["LibGit2"] -git-tree-sha1 = "c36550cb29cbe373e95b3f40486b9a4148f89ffd" -uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" -version = "0.9.2" - -[[deps.Downloads]] -deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] -uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" -version = "1.6.0" - -[[deps.ExprTools]] -git-tree-sha1 = "56559bbef6ca5ea0c0818fa5c90320398a6fbf8d" -uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04" -version = "0.1.8" - -[[deps.FLoops]] -deps = ["BangBang", "Compat", "FLoopsBase", "InitialValues", "JuliaVariables", "MLStyle", "Serialization", "Setfield", "Transducers"] -git-tree-sha1 = "ffb97765602e3cbe59a0589d237bf07f245a8576" -uuid = "cc61a311-1640-44b5-9fba-1b764f453329" -version = "0.2.1" - -[[deps.FLoopsBase]] -deps = ["ContextVariablesX"] -git-tree-sha1 = "656f7a6859be8673bf1f35da5670246b923964f7" -uuid = "b9860ae5-e623-471e-878b-f6a53c775ea6" -version = "0.1.1" - -[[deps.FileWatching]] -uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" - -[[deps.FillArrays]] -deps = ["LinearAlgebra", "Random", "SparseArrays", "Statistics"] -git-tree-sha1 = "802bfc139833d2ba893dd9e62ba1767c88d708ae" -uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "0.13.5" - -[[deps.Flux]] -deps = ["Adapt", "ArrayInterface", "CUDA", "ChainRulesCore", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "NNlibCUDA", "OneHotArrays", "Optimisers", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "StatsBase", "Test", "Zygote"] -git-tree-sha1 = "66b62bf72c4b5d4904441ed0677eab53266033c7" -uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" -version = "0.13.7" - -[[deps.FoldsThreads]] -deps = ["Accessors", "FunctionWrappers", "InitialValues", "SplittablesBase", "Transducers"] -git-tree-sha1 = "eb8e1989b9028f7e0985b4268dabe94682249025" -uuid = "9c68100b-dfe1-47cf-94c8-95104e173443" -version = "0.1.1" - -[[deps.ForwardDiff]] -deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions", "StaticArrays"] -git-tree-sha1 = "10fa12fe96e4d76acfa738f4df2126589a67374f" -uuid = "f6369f11-7733-5829-9624-2563aa707210" -version = "0.10.33" - -[[deps.FunctionWrappers]] -git-tree-sha1 = "d62485945ce5ae9c0c48f124a84998d755bae00e" -uuid = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e" -version = "1.1.3" - -[[deps.Functors]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "a2657dd0f3e8a61dbe70fc7c122038bd33790af5" -uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196" -version = "0.3.0" - -[[deps.Future]] -deps = ["Random"] -uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" - -[[deps.GPUArrays]] -deps = ["Adapt", "GPUArraysCore", "LLVM", "LinearAlgebra", "Printf", "Random", "Reexport", "Serialization", "Statistics"] -git-tree-sha1 = "45d7deaf05cbb44116ba785d147c518ab46352d7" -uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" -version = "8.5.0" - -[[deps.GPUArraysCore]] -deps = ["Adapt"] -git-tree-sha1 = "6872f5ec8fd1a38880f027a26739d42dcda6691f" -uuid = "46192b85-c4d5-4398-a991-12ede77f4527" -version = "0.1.2" - -[[deps.GPUCompiler]] -deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "TimerOutputs", "UUIDs"] -git-tree-sha1 = "76f70a337a153c1632104af19d29023dbb6f30dd" -uuid = "61eb1bfa-7361-4325-ad38-22787b887f55" -version = "0.16.6" - -[[deps.IRTools]] -deps = ["InteractiveUtils", "MacroTools", "Test"] -git-tree-sha1 = "2e99184fca5eb6f075944b04c22edec29beb4778" -uuid = "7869d1d1-7146-5819-86e3-90919afe41df" -version = "0.4.7" - -[[deps.IfElse]] -git-tree-sha1 = "debdd00ffef04665ccbb3e150747a77560e8fad1" -uuid = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" -version = "0.1.1" - -[[deps.InitialValues]] -git-tree-sha1 = "4da0f88e9a39111c2fa3add390ab15f3a44f3ca3" -uuid = "22cec73e-a1b8-11e9-2c92-598750a2cf9c" -version = "0.3.1" - -[[deps.InteractiveUtils]] -deps = ["Markdown"] -uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" - -[[deps.InverseFunctions]] -deps = ["Test"] -git-tree-sha1 = "49510dfcb407e572524ba94aeae2fced1f3feb0f" -uuid = "3587e190-3f89-42d0-90ee-14403ec27112" -version = "0.1.8" - -[[deps.IrrationalConstants]] -git-tree-sha1 = "7fd44fd4ff43fc60815f8e764c0f352b83c49151" -uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" -version = "0.1.1" - -[[deps.IteratorInterfaceExtensions]] -git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" -uuid = "82899510-4779-5014-852e-03e436cf321d" -version = "1.0.0" - -[[deps.JLLWrappers]] -deps = ["Preferences"] -git-tree-sha1 = "abc9885a7ca2052a736a600f7fa66209f96506e1" -uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" -version = "1.4.1" - -[[deps.JuliaVariables]] -deps = ["MLStyle", "NameResolution"] -git-tree-sha1 = "49fb3cb53362ddadb4415e9b73926d6b40709e70" -uuid = "b14d175d-62b4-44ba-8fb7-3064adc8c3ec" -version = "0.2.4" - -[[deps.LLVM]] -deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Printf", "Unicode"] -git-tree-sha1 = "e7e9184b0bf0158ac4e4aa9daf00041b5909bf1a" -uuid = "929cbde3-209d-540e-8aea-75f648917ca0" -version = "4.14.0" - -[[deps.LLVMExtra_jll]] -deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg", "TOML"] -git-tree-sha1 = "771bfe376249626d3ca12bcd58ba243d3f961576" -uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab" -version = "0.0.16+0" - -[[deps.LazyArtifacts]] -deps = ["Artifacts", "Pkg"] -uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3" - -[[deps.LibCURL]] -deps = ["LibCURL_jll", "MozillaCACerts_jll"] -uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" -version = "0.6.3" - -[[deps.LibCURL_jll]] -deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] -uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" -version = "7.84.0+0" - -[[deps.LibGit2]] -deps = ["Base64", "NetworkOptions", "Printf", "SHA"] -uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" - -[[deps.LibSSH2_jll]] -deps = ["Artifacts", "Libdl", "MbedTLS_jll"] -uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" -version = "1.10.2+0" - -[[deps.Libdl]] -uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" - -[[deps.LinearAlgebra]] -deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"] -uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" - -[[deps.LogExpFunctions]] -deps = ["ChainRulesCore", "ChangesOfVariables", "DocStringExtensions", "InverseFunctions", "IrrationalConstants", "LinearAlgebra"] -git-tree-sha1 = "94d9c52ca447e23eac0c0f074effbcd38830deb5" -uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" -version = "0.3.18" - -[[deps.Logging]] -uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" - -[[deps.MLStyle]] -git-tree-sha1 = "060ef7956fef2dc06b0e63b294f7dbfbcbdc7ea2" -uuid = "d8e11817-5142-5d16-987a-aa16d5891078" -version = "0.4.16" - -[[deps.MLUtils]] -deps = ["ChainRulesCore", "DelimitedFiles", "FLoops", "FoldsThreads", "Random", "ShowCases", "Statistics", "StatsBase", "Transducers"] -git-tree-sha1 = "824e9dfc7509cab1ec73ba77b55a916bb2905e26" -uuid = "f1d291b0-491e-4a28-83b9-f70985020b54" -version = "0.2.11" - -[[deps.MacroTools]] -deps = ["Markdown", "Random"] -git-tree-sha1 = "42324d08725e200c23d4dfb549e0d5d89dede2d2" -uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" -version = "0.5.10" - -[[deps.Markdown]] -deps = ["Base64"] -uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" - -[[deps.MbedTLS_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" -version = "2.28.0+0" - -[[deps.MicroCollections]] -deps = ["BangBang", "InitialValues", "Setfield"] -git-tree-sha1 = "4d5917a26ca33c66c8e5ca3247bd163624d35493" -uuid = "128add7d-3638-4c79-886c-908ea0c25c34" -version = "0.1.3" - -[[deps.Missings]] -deps = ["DataAPI"] -git-tree-sha1 = "bf210ce90b6c9eed32d25dbcae1ebc565df2687f" -uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" -version = "1.0.2" - -[[deps.Mmap]] -uuid = "a63ad114-7e13-5084-954f-fe012c677804" - -[[deps.MozillaCACerts_jll]] -uuid = "14a3606d-f60d-562e-9121-12d972cd8159" -version = "2022.10.11" - -[[deps.NNlib]] -deps = ["Adapt", "ChainRulesCore", "LinearAlgebra", "Pkg", "Requires", "Statistics"] -git-tree-sha1 = "00bcfcea7b2063807fdcab2e0ce86ef00b8b8000" -uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.8.10" - -[[deps.NNlibCUDA]] -deps = ["Adapt", "CUDA", "LinearAlgebra", "NNlib", "Random", "Statistics"] -git-tree-sha1 = "4429261364c5ea5b7308aecaa10e803ace101631" -uuid = "a00861dc-f156-4864-bf3c-e6376f28a68d" -version = "0.2.4" - -[[deps.NaNMath]] -deps = ["OpenLibm_jll"] -git-tree-sha1 = "a7c3d1da1189a1c2fe843a3bfa04d18d20eb3211" -uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" -version = "1.0.1" - -[[deps.NameResolution]] -deps = ["PrettyPrint"] -git-tree-sha1 = "1a0fa0e9613f46c9b8c11eee38ebb4f590013c5e" -uuid = "71a1bf82-56d0-4bbc-8a3c-48b961074391" -version = "0.1.5" - -[[deps.NetworkOptions]] -uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" -version = "1.2.0" - -[[deps.OneHotArrays]] -deps = ["Adapt", "ChainRulesCore", "GPUArrays", "LinearAlgebra", "MLUtils", "NNlib"] -git-tree-sha1 = "2f6efe2f76d57a0ee67cb6eff49b4d02fccbd175" -uuid = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" -version = "0.1.0" - -[[deps.OpenBLAS_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] -uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" -version = "0.3.21+0" - -[[deps.OpenLibm_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "05823500-19ac-5b8b-9628-191a04bc5112" -version = "0.8.1+0" - -[[deps.OpenSpecFun_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1" -uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" -version = "0.5.5+0" - -[[deps.Optimisers]] -deps = ["ChainRulesCore", "Functors", "LinearAlgebra", "Random", "Statistics"] -git-tree-sha1 = "8a9102cb805df46fc3d6effdc2917f09b0215c0b" -uuid = "3bd65402-5787-11e9-1adc-39752487f4e2" -version = "0.2.10" - -[[deps.OrderedCollections]] -git-tree-sha1 = "85f8e6578bf1f9ee0d11e7bb1b1456435479d47c" -uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" -version = "1.4.1" - -[[deps.Pkg]] -deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] -uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -version = "1.8.0" - -[[deps.Preferences]] -deps = ["TOML"] -git-tree-sha1 = "47e5f437cc0e7ef2ce8406ce1e7e24d44915f88d" -uuid = "21216c6a-2e73-6563-6e65-726566657250" -version = "1.3.0" - -[[deps.PrettyPrint]] -git-tree-sha1 = "632eb4abab3449ab30c5e1afaa874f0b98b586e4" -uuid = "8162dcfd-2161-5ef2-ae6c-7681170c5f98" -version = "0.2.0" - -[[deps.Printf]] -deps = ["Unicode"] -uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" - -[[deps.ProgressLogging]] -deps = ["Logging", "SHA", "UUIDs"] -git-tree-sha1 = "80d919dee55b9c50e8d9e2da5eeafff3fe58b539" -uuid = "33c8b6b6-d38a-422a-b730-caa89a2f386c" -version = "0.1.4" - -[[deps.REPL]] -deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] -uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" - -[[deps.Random]] -deps = ["SHA", "Serialization"] -uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" - -[[deps.Random123]] -deps = ["Random", "RandomNumbers"] -git-tree-sha1 = "7a1a306b72cfa60634f03a911405f4e64d1b718b" -uuid = "74087812-796a-5b5d-8853-05524746bad3" -version = "1.6.0" - -[[deps.RandomNumbers]] -deps = ["Random", "Requires"] -git-tree-sha1 = "043da614cc7e95c703498a491e2c21f58a2b8111" -uuid = "e6cf234a-135c-5ec9-84dd-332b85af5143" -version = "1.5.3" - -[[deps.RealDot]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "9f0a1b71baaf7650f4fa8a1d168c7fb6ee41f0c9" -uuid = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9" -version = "0.1.0" - -[[deps.Reexport]] -git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b" -uuid = "189a3867-3050-52da-a836-e630ba90ab69" -version = "1.2.2" - -[[deps.Requires]] -deps = ["UUIDs"] -git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7" -uuid = "ae029012-a4dd-5104-9daa-d747884805df" -version = "1.3.0" - -[[deps.SHA]] -uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" -version = "0.7.0" - -[[deps.Serialization]] -uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" - -[[deps.Setfield]] -deps = ["ConstructionBase", "Future", "MacroTools", "StaticArraysCore"] -git-tree-sha1 = "e2cc6d8c88613c05e1defb55170bf5ff211fbeac" -uuid = "efcf1570-3423-57d1-acb7-fd33fddbac46" -version = "1.1.1" - -[[deps.ShowCases]] -git-tree-sha1 = "7f534ad62ab2bd48591bdeac81994ea8c445e4a5" -uuid = "605ecd9f-84a6-4c9e-81e2-4798472b76a3" -version = "0.1.0" - -[[deps.Sockets]] -uuid = "6462fe0b-24de-5631-8697-dd941f90decc" - -[[deps.SortingAlgorithms]] -deps = ["DataStructures"] -git-tree-sha1 = "a4ada03f999bd01b3a25dcaa30b2d929fe537e00" -uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" -version = "1.1.0" - -[[deps.SparseArrays]] -deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] -uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" - -[[deps.SpecialFunctions]] -deps = ["ChainRulesCore", "IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] -git-tree-sha1 = "d75bda01f8c31ebb72df80a46c88b25d1c79c56d" -uuid = "276daf66-3868-5448-9aa4-cd146d93841b" -version = "2.1.7" - -[[deps.SplittablesBase]] -deps = ["Setfield", "Test"] -git-tree-sha1 = "e08a62abc517eb79667d0a29dc08a3b589516bb5" -uuid = "171d559e-b47b-412a-8079-5efa626c420e" -version = "0.1.15" - -[[deps.Static]] -deps = ["IfElse"] -git-tree-sha1 = "03170c1e8a94732c1d835ce4c5b904b4b52cba1c" -uuid = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" -version = "0.7.8" - -[[deps.StaticArrays]] -deps = ["LinearAlgebra", "Random", "StaticArraysCore", "Statistics"] -git-tree-sha1 = "f86b3a049e5d05227b10e15dbb315c5b90f14988" -uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.5.9" - -[[deps.StaticArraysCore]] -git-tree-sha1 = "6b7ba252635a5eff6a0b0664a41ee140a1c9e72a" -uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" -version = "1.4.0" - -[[deps.Statistics]] -deps = ["LinearAlgebra", "SparseArrays"] -uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -version = "1.9.0" - -[[deps.StatsAPI]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "f9af7f195fb13589dd2e2d57fdb401717d2eb1f6" -uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" -version = "1.5.0" - -[[deps.StatsBase]] -deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] -git-tree-sha1 = "d1bf48bfcc554a3761a133fe3a9bb01488e06916" -uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -version = "0.33.21" - -[[deps.StructArrays]] -deps = ["Adapt", "DataAPI", "StaticArraysCore", "Tables"] -git-tree-sha1 = "13237798b407150a6d2e2bce5d793d7d9576e99e" -uuid = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" -version = "0.6.13" - -[[deps.SuiteSparse]] -deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] -uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" - -[[deps.SuiteSparse_jll]] -deps = ["Artifacts", "Libdl", "Pkg", "libblastrampoline_jll"] -uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" -version = "5.10.1+0" - -[[deps.TOML]] -deps = ["Dates"] -uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" -version = "1.0.3" - -[[deps.TableTraits]] -deps = ["IteratorInterfaceExtensions"] -git-tree-sha1 = "c06b2f539df1c6efa794486abfb6ed2022561a39" -uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c" -version = "1.0.1" - -[[deps.Tables]] -deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "LinearAlgebra", "OrderedCollections", "TableTraits", "Test"] -git-tree-sha1 = "c79322d36826aa2f4fd8ecfa96ddb47b174ac78d" -uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" -version = "1.10.0" - -[[deps.Tar]] -deps = ["ArgTools", "SHA"] -uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" -version = "1.10.0" - -[[deps.Test]] -deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] -uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[[deps.TimerOutputs]] -deps = ["ExprTools", "Printf"] -git-tree-sha1 = "f2fd3f288dfc6f507b0c3a2eb3bac009251e548b" -uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" -version = "0.5.22" - -[[deps.Transducers]] -deps = ["Adapt", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "Setfield", "SplittablesBase", "Tables"] -git-tree-sha1 = "77fea79baa5b22aeda896a8d9c6445a74500a2c2" -uuid = "28d57a85-8fef-5791-bfe6-a80928e7c999" -version = "0.4.74" - -[[deps.UUIDs]] -deps = ["Random", "SHA"] -uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" - -[[deps.Unicode]] -uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" - -[[deps.Zlib_jll]] -deps = ["Libdl"] -uuid = "83775a58-1f1d-513f-b197-d71354ab007a" -version = "1.2.13+0" - -[[deps.Zygote]] -deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "GPUArrays", "GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"] -git-tree-sha1 = "66cc604b9a27a660e25a54e408b4371123a186a6" -uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.6.49" - -[[deps.ZygoteRules]] -deps = ["MacroTools"] -git-tree-sha1 = "8c1a8e4dfacb1fd631745552c8db35d0deb09ea0" -uuid = "700de1a5-db45-46bc-99cf-38207098b444" -version = "0.2.2" - -[[deps.libblastrampoline_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" -version = "5.2.0+0" - -[[deps.nghttp2_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" -version = "1.48.0+0" - -[[deps.p7zip_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" -version = "17.4.0+0" diff --git a/Project.toml b/Project.toml index 59d91fe..a44f650 100644 --- a/Project.toml +++ b/Project.toml @@ -6,12 +6,14 @@ version = "0.1.0" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" +ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] Flux = "0.13.7" NNlib = "0.8.10" Optimisers = "0.2.10" +ProgressMeter = "1.7.2" Zygote = "0.6.49" julia = "1.6" diff --git a/README.md b/README.md index 6374cc5..3b7d68f 100644 --- a/README.md +++ b/README.md @@ -34,4 +34,4 @@ As will any features which migrate to Flux itself. ## Current Features * Layers `Split` and `Join` - +* A more advanced `train!` diff --git a/src/Fluxperimental.jl b/src/Fluxperimental.jl index 047310a..26026d3 100644 --- a/src/Fluxperimental.jl +++ b/src/Fluxperimental.jl @@ -5,4 +5,7 @@ using Flux include("split_join.jl") export Split, Join +include("train.jl") +export shinkansen! + end # module Fluxperimental diff --git a/src/train.jl b/src/train.jl new file mode 100644 index 0000000..095f923 --- /dev/null +++ b/src/train.jl @@ -0,0 +1,78 @@ +using Flux: withgradient, DataLoader +using Optimisers: Optimisers +using ProgressMeter: ProgressMeter, Progress, next! + +#= + +This grew out of explicit-mode upgrade here: +https://github.com/FluxML/Flux.jl/pull/2082 + +=# + +""" + shinkansen!(loss, model, data...; state, epochs=1, [batchsize, keywords...]) + +This is a re-design of `train!`: + +* The loss function must accept the remaining arguments: `loss(model, data...)` +* The optimiser state from `setup` must be passed to the keyword `state`. + +By default it calls `gradient(loss, model, data...)` just like that. +Same order as the arguments. If you specify `epochs = 100`, then it will do this 100 times. + +But if you specify `batchsize = 32`, then it first makes `DataLoader(data...; batchsize)`, +and uses that to generate smaller arrays to feed to `gradient`. +All other keywords are passed to `DataLoader`, e.g. to shuffle batches. + +Returns the loss from every call. + +# Example +``` +X = repeat(hcat(digits.(0:3, base=2, pad=2)...), 1, 32) +Y = Flux.onehotbatch(xor.(eachrow(X)...), 0:1) + +model = Chain(Dense(2 => 3, sigmoid), BatchNorm(3), Dense(3 => 2)) +state = Flux.setup(Adam(0.1, (0.7, 0.95)), model) +# state = Optimisers.setup(Optimisers.Adam(0.1, (0.7, 0.95)), model) # for now + +shinkansen!(model, X, Y; state, epochs=100, batchsize=16, shuffle=true) do m, x, y + Flux.logitcrossentropy(m(x), y) +end + +all((softmax(model(X)) .> 0.5) .== Y) +``` +""" +function shinkansen!(loss::Function, model, data...; state, epochs=1, batchsize=nothing, kw...) + if batchsize != nothing + loader = DataLoader(data; batchsize, kw...) + losses = Vector{Float32}[] + prog = Progress(length(loader) * epochs) + + for e in 1:epochs + eplosses = Float32[] + for (i,d) in enumerate(loader) + l, (g, _...) = withgradient(loss, model, d...) + isfinite(l) || error("loss is $l, on batch $i, epoch $epoch") + Optimisers.update!(state, model, g) + push!(eplosses, l) + next!(prog; showvalues=[(:epoch, e), (:loss, l)]) + end + push!(losses, eplosses) + end + + return allequal(size.(losses)) ? reduce(hcat, losses) : losses + else + losses = Float32[] + prog = Progress(epochs) + + for e in 1:epochs + l, (g, _...) = withgradient(loss, model, data...) + isfinite(l) || error("loss is $l, on epoch $epoch") + Optimisers.update!(state, model, g) + push!(losses, l) + next!(prog; showvalues=[(:epoch, epoch), (:loss, l)]) + end + + return losses + end +end diff --git a/test/train.jl b/test/train.jl new file mode 100644 index 0000000..8c0a710 --- /dev/null +++ b/test/train.jl @@ -0,0 +1,27 @@ +import Flux, Fluxperimental, Optimisers + +@testset "shinkansen!" begin + + X = repeat(hcat(digits.(0:3, base=2, pad=2)...), 1, 32) + Y = Flux.onehotbatch(xor.(eachrow(X)...), 0:1) + + model = Flux.Chain(Flux.Dense(2 => 3, Flux.sigmoid), Flux.BatchNorm(3), Flux.Dense(3 => 2)) + state = Optimisers.setup(Optimisers.Adam(0.1, (0.7, 0.95)), model) + + Fluxperimental.shinkansen!(model, X, Y; state, epochs=100) do m, x, y + Flux.logitcrossentropy(m(x), y) + end + + @test all((Flux.softmax(model(X)) .> 0.5) .== Y) + + model = Flux.Chain(Flux.Dense(2 => 3, Flux.sigmoid), Flux.BatchNorm(3), Flux.Dense(3 => 2)) + state = Optimisers.setup(Optimisers.Adam(0.1, (0.7, 0.95)), model) + + Fluxperimental.shinkansen!(model, X, Y; state, epochs=100, batchsize=16, shuffle=true) do m, x, y + Flux.logitcrossentropy(m(x), y) + end + + @test all((Flux.softmax(model(X)) .> 0.5) .== Y) + +end +