From ee3e0b8a9835909d9b489faf685aee4928fa893d Mon Sep 17 00:00:00 2001 From: Essam Date: Wed, 22 May 2024 19:29:01 +0300 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20Add=20basic=20docs=20skeleton=20and?= =?UTF-8?q?=20README=20integration?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/Manifest.toml | 394 +++++++++++++++++- docs/Project.toml | 1 + docs/make.jl | 23 +- docs/src/api.md | 36 -- docs/src/assets/themes/documenter-light.css | 14 + docs/src/contributing.md | 15 + .../early.md => full tutorials/Boston.md} | 0 docs/src/full tutorials/MNIST.md | 96 +++++ docs/src/index.md | 26 +- docs/src/interface/Builders.md | 16 + docs/src/interface/Classification.md | 3 + docs/src/interface/Custom Builders.md | 61 +++ docs/src/interface/Image Classification.md | 3 + docs/src/interface/Multitarget Regression.md | 3 + docs/src/interface/Regression.md | 3 + docs/src/interface/Summary.md | 143 +++++++ .../Composition.md} | 0 docs/src/workflow examples/Early Stopping.md | 0 .../Hyperparameter Tuning.md | 0 .../workflow examples/Incremental Training.md | 55 +++ 20 files changed, 842 insertions(+), 50 deletions(-) delete mode 100644 docs/src/api.md rename docs/src/{features/early.md => full tutorials/Boston.md} (100%) create mode 100644 docs/src/full tutorials/MNIST.md create mode 100644 docs/src/interface/Builders.md create mode 100644 docs/src/interface/Classification.md create mode 100644 docs/src/interface/Custom Builders.md create mode 100644 docs/src/interface/Image Classification.md create mode 100644 docs/src/interface/Multitarget Regression.md create mode 100644 docs/src/interface/Regression.md create mode 100644 docs/src/interface/Summary.md rename docs/src/{features/tuning.md => workflow examples/Composition.md} (100%) create mode 100644 docs/src/workflow examples/Early Stopping.md create mode 100644 docs/src/workflow examples/Hyperparameter Tuning.md create mode 100644 docs/src/workflow examples/Incremental Training.md 
diff --git a/docs/Manifest.toml b/docs/Manifest.toml index d9bc7bb7..2ec5f8aa 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -2,7 +2,7 @@ julia_version = "1.10.0" manifest_format = "2.0" -project_hash = "8237dd01902c50351547fc838fcc6b6ea3cfb2cb" +project_hash = "760378c053aeb477e203dc95fb0a527c7911c8d1" [[deps.ANSIColoredPrinters]] git-tree-sha1 = "574baf8110975760d391c710b6341da1afa48d8c" @@ -53,6 +53,18 @@ git-tree-sha1 = "c06a868224ecba914baa6942988e2f2aade419be" uuid = "a9b6321e-bd34-4604-b9c9-b65b8de01458" version = "0.1.0" +[[deps.AtomsBase]] +deps = ["LinearAlgebra", "PeriodicTable", "Printf", "Requires", "StaticArrays", "Unitful", "UnitfulAtomic"] +git-tree-sha1 = "995c2b6b17840cd87b722ce9c6cdd72f47bab545" +uuid = "a963bdd2-2df7-4f54-a1ee-49d51e6be12a" +version = "0.3.5" + +[[deps.BFloat16s]] +deps = ["LinearAlgebra", "Printf", "Random", "Test"] +git-tree-sha1 = "2c7cc21e8678eff479978a0a2ef5ce2f51b63dff" +uuid = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" +version = "0.5.0" + [[deps.BSON]] git-tree-sha1 = "4c3e506685c527ac6a54ccc0c8c76fd6f91b42fb" uuid = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0" @@ -86,11 +98,27 @@ git-tree-sha1 = "aebf55e6d7795e02ca500a689d326ac979aaf89e" uuid = "9718e550-a3fa-408a-8086-8db961cd8217" version = "0.1.1" +[[deps.BitFlags]] +git-tree-sha1 = "2dc09997850d68179b69dafb58ae806167a32b1b" +uuid = "d1d4a3ce-64b1-5f1a-9ba4-7e7e69966f35" +version = "0.1.8" + +[[deps.BufferedStreams]] +git-tree-sha1 = "4ae47f9a4b1dc19897d3743ff13685925c5202ec" +uuid = "e1450e63-4bb3-523b-b2a4-4ffa8c0fd77d" +version = "1.2.1" + [[deps.CEnum]] git-tree-sha1 = "389ad5c84de1ae7cf0e28e381131c98ea87d54fc" uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" version = "0.5.0" +[[deps.CSV]] +deps = ["CodecZlib", "Dates", "FilePathsBase", "InlineStrings", "Mmap", "Parsers", "PooledArrays", "PrecompileTools", "SentinelArrays", "Tables", "Unicode", "WeakRefStrings", "WorkerUtilities"] +git-tree-sha1 = "6c834533dc1fabd820c1db03c839bf97e45a3fab" +uuid = 
"336ed68f-0bac-5ca0-87d4-7b16caf5d00b" +version = "0.10.14" + [[deps.CategoricalArrays]] deps = ["DataAPI", "Future", "Missings", "Printf", "Requires", "Statistics", "Unicode"] git-tree-sha1 = "1568b28f91293458345dabba6a5ea3f183250a61" @@ -125,18 +153,52 @@ weakdeps = ["SparseArrays"] [deps.ChainRulesCore.extensions] ChainRulesCoreSparseArraysExt = "SparseArrays" +[[deps.Chemfiles]] +deps = ["AtomsBase", "Chemfiles_jll", "DocStringExtensions", "PeriodicTable", "Unitful", "UnitfulAtomic"] +git-tree-sha1 = "82fe5e341c793cb51149d993307da9543824b206" +uuid = "46823bd8-5fb3-5f92-9aa0-96921f3dd015" +version = "0.10.41" + +[[deps.Chemfiles_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "f3743181e30d87c23d9c8ebd493b77f43d8f1890" +uuid = "78a364fa-1a3c-552a-b4bb-8fa0f9c1fcca" +version = "0.10.4+0" + [[deps.CodecZlib]] deps = ["TranscodingStreams", "Zlib_jll"] git-tree-sha1 = "59939d8a997469ee05c4b4944560a820f9ba0d73" uuid = "944b1d66-785c-5afd-91f1-9de20f533193" version = "0.7.4" +[[deps.ColorSchemes]] +deps = ["ColorTypes", "ColorVectorSpace", "Colors", "FixedPointNumbers", "PrecompileTools", "Random"] +git-tree-sha1 = "4b270d6465eb21ae89b732182c20dc165f8bf9f2" +uuid = "35d6a980-a343-548e-a6ea-1d62b119f2f4" +version = "3.25.0" + [[deps.ColorTypes]] deps = ["FixedPointNumbers", "Random"] git-tree-sha1 = "b10d0b65641d57b8b4d5e234446582de5047050d" uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f" version = "0.11.5" +[[deps.ColorVectorSpace]] +deps = ["ColorTypes", "FixedPointNumbers", "LinearAlgebra", "Requires", "Statistics", "TensorCore"] +git-tree-sha1 = "a1f44953f2382ebb937d60dafbe2deea4bd23249" +uuid = "c3611d14-8923-5661-9e6a-0046d554d3a4" +version = "0.10.0" +weakdeps = ["SpecialFunctions"] + + [deps.ColorVectorSpace.extensions] + SpecialFunctionsExt = "SpecialFunctions" + +[[deps.Colors]] +deps = ["ColorTypes", "FixedPointNumbers", "Reexport"] +git-tree-sha1 = "362a287c3aa50601b0bc359053d5c2468f0e7ce0" +uuid = "5ae59095-9a9b-59fe-a467-6f913c188581" 
+version = "0.12.11" + [[deps.CommonSubexpressions]] deps = ["MacroTools", "Test"] git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7" @@ -174,6 +236,12 @@ git-tree-sha1 = "52cb3ec90e8a8bea0e62e275ba577ad0f74821f7" uuid = "ed09eef8-17a6-5b46-8889-db040fac31e3" version = "0.3.2" +[[deps.ConcurrentUtilities]] +deps = ["Serialization", "Sockets"] +git-tree-sha1 = "6cbbd4d241d7e6579ab354737f4dd95ca43946e1" +uuid = "f0e56b4a-5159-44fe-b623-3e5288b988bb" +version = "2.4.1" + [[deps.ConstructionBase]] deps = ["LinearAlgebra"] git-tree-sha1 = "260fd2400ed2dab602a7c15cf10c1933c59930a2" @@ -194,11 +262,28 @@ git-tree-sha1 = "25cc3803f1030ab855e383129dcd3dc294e322cc" uuid = "6add18c4-b38d-439d-96f6-d6bc489c04c5" version = "0.1.3" +[[deps.Crayons]] +git-tree-sha1 = "249fe38abf76d48563e2f4556bebd215aa317e15" +uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" +version = "4.1.1" + [[deps.DataAPI]] git-tree-sha1 = "abe83f3a2f1b857aac70ef8b269080af17764bbe" uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" version = "1.16.0" +[[deps.DataDeps]] +deps = ["HTTP", "Libdl", "Reexport", "SHA", "Scratch", "p7zip_jll"] +git-tree-sha1 = "8ae085b71c462c2cb1cfedcb10c3c877ec6cf03f" +uuid = "124859b0-ceae-595e-8997-d05f6a7a8dfe" +version = "0.7.13" + +[[deps.DataFrames]] +deps = ["Compat", "DataAPI", "DataStructures", "Future", "InlineStrings", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrecompileTools", "PrettyTables", "Printf", "REPL", "Random", "Reexport", "SentinelArrays", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"] +git-tree-sha1 = "04c738083f29f86e62c8afc341f0967d8717bdb8" +uuid = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" +version = "1.6.1" + [[deps.DataStructures]] deps = ["Compat", "InteractiveUtils", "OrderedCollections"] git-tree-sha1 = "1d0a14036acb104d9e89698bd408f63ab58cdc82" @@ -264,6 +349,12 @@ deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] uuid = 
"f43a241f-c20a-4ad4-852c-f6b1247861c6" version = "1.6.0" +[[deps.ExceptionUnwrapping]] +deps = ["Test"] +git-tree-sha1 = "dcb08a0d93ec0b1cdc4af184b26b591e9695423a" +uuid = "460bff9d-24e4-43bc-9d9f-a8973cb893f4" +version = "0.1.10" + [[deps.Expat_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] git-tree-sha1 = "1c6317308b9dc757616f0b5cb379db10494443a7" @@ -288,6 +379,12 @@ git-tree-sha1 = "82d8afa92ecf4b52d78d869f038ebfb881267322" uuid = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549" version = "1.16.3" +[[deps.FilePathsBase]] +deps = ["Compat", "Dates", "Mmap", "Printf", "Test", "UUIDs"] +git-tree-sha1 = "9f00e42f8d99fdde64d40c8ea5d14269a2e2c1aa" +uuid = "48062228-2e41-5def-b9a4-89aafe57970f" +version = "0.9.21" + [[deps.FileWatching]] uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" @@ -363,6 +460,12 @@ git-tree-sha1 = "ec632f177c0d990e64d955ccc1b8c04c485a0950" uuid = "46192b85-c4d5-4398-a991-12ede77f4527" version = "0.1.6" +[[deps.GZip]] +deps = ["Libdl", "Zlib_jll"] +git-tree-sha1 = "0085ccd5ec327c077ec5b91a5f937b759810ba62" +uuid = "92fee26a-97fe-5a0c-ad85-20a5f3185b63" +version = "0.6.2" + [[deps.Git]] deps = ["Git_jll"] git-tree-sha1 = "51764e6c2e84c37055e846c516e9015b4a291c7d" @@ -375,6 +478,11 @@ git-tree-sha1 = "d8be4aab0f4e043cc40984e9097417307cce4c03" uuid = "f8c6e375-362e-5223-8a59-34ff63f689eb" version = "2.36.1+2" +[[deps.Glob]] +git-tree-sha1 = "97285bbd5230dd766e9ef6749b80fc617126d496" +uuid = "c27321d9-0574-5035-807b-f59d2c89b15c" +version = "1.3.1" + [[deps.Gumbo]] deps = ["AbstractTrees", "Gumbo_jll", "Libdl"] git-tree-sha1 = "a1a138dfbf9df5bace489c7a9d5196d6afdfa140" @@ -387,6 +495,30 @@ git-tree-sha1 = "29070dee9df18d9565276d68a596854b1764aa38" uuid = "528830af-5a63-567c-a44a-034ed33b8444" version = "0.10.2+0" +[[deps.HDF5]] +deps = ["Compat", "HDF5_jll", "Libdl", "MPIPreferences", "Mmap", "Preferences", "Printf", "Random", "Requires", "UUIDs"] +git-tree-sha1 = "e856eef26cf5bf2b0f95f8f4fc37553c72c8641c" +uuid = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f" 
+version = "0.17.2" + + [deps.HDF5.extensions] + MPIExt = "MPI" + + [deps.HDF5.weakdeps] + MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" + +[[deps.HDF5_jll]] +deps = ["Artifacts", "JLLWrappers", "LibCURL_jll", "Libdl", "OpenSSL_jll", "Pkg", "Zlib_jll"] +git-tree-sha1 = "4cc2bb72df6ff40b055295fdef6d92955f9dede8" +uuid = "0234f1f7-429e-5d53-9886-15a909be8d59" +version = "1.12.2+2" + +[[deps.HTTP]] +deps = ["Base64", "CodecZlib", "ConcurrentUtilities", "Dates", "ExceptionUnwrapping", "Logging", "LoggingExtras", "MbedTLS", "NetworkOptions", "OpenSSL", "Random", "SimpleBufferStream", "Sockets", "URIs", "UUIDs"] +git-tree-sha1 = "d1d712be3164d61d1fb98e7ce9bcbc6cc06b45ed" +uuid = "cd3eb016-35fb-5094-929b-558a96fad6f3" +version = "1.10.8" + [[deps.IOCapture]] deps = ["Logging", "Random"] git-tree-sha1 = "8b72179abc660bfab5e28472e019392b97d0985c" @@ -399,15 +531,50 @@ git-tree-sha1 = "950c3717af761bc3ff906c2e8e52bd83390b6ec2" uuid = "7869d1d1-7146-5819-86e3-90919afe41df" version = "0.4.14" +[[deps.ImageBase]] +deps = ["ImageCore", "Reexport"] +git-tree-sha1 = "eb49b82c172811fd2c86759fa0553a2221feb909" +uuid = "c817782e-172a-44cc-b673-b171935fbb9e" +version = "0.1.7" + +[[deps.ImageCore]] +deps = ["ColorVectorSpace", "Colors", "FixedPointNumbers", "MappedArrays", "MosaicViews", "OffsetArrays", "PaddedViews", "PrecompileTools", "Reexport"] +git-tree-sha1 = "b2a7eaa169c13f5bcae8131a83bc30eff8f71be0" +uuid = "a09fc81d-aa75-5fe9-8630-4744c3626534" +version = "0.10.2" + +[[deps.ImageShow]] +deps = ["Base64", "ColorSchemes", "FileIO", "ImageBase", "ImageCore", "OffsetArrays", "StackViews"] +git-tree-sha1 = "3b5344bcdbdc11ad58f3b1956709b5b9345355de" +uuid = "4e3cecfd-b093-5904-9786-8bbb286a6a31" +version = "0.3.8" + [[deps.InitialValues]] git-tree-sha1 = "4da0f88e9a39111c2fa3add390ab15f3a44f3ca3" uuid = "22cec73e-a1b8-11e9-2c92-598750a2cf9c" version = "0.3.1" +[[deps.InlineStrings]] +deps = ["Parsers"] +git-tree-sha1 = "9cc2baf75c6d09f9da536ddf58eb2f29dedaf461" +uuid = 
"842dd82b-1e85-43dc-bf29-5d0ee9dffc48" +version = "1.4.0" + [[deps.InteractiveUtils]] deps = ["Markdown"] uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" +[[deps.InternedStrings]] +deps = ["Random", "Test"] +git-tree-sha1 = "eb05b5625bc5d821b8075a77e4c421933e20c76b" +uuid = "7d512f48-7fb1-5a58-b986-67e6dc259f01" +version = "0.7.0" + +[[deps.InvertedIndices]] +git-tree-sha1 = "0dc7b50b8d436461be01300fd8cd45aa0274b038" +uuid = "41ab1584-1d38-5bbf-9106-f11c6c58b48f" +version = "1.3.0" + [[deps.IrrationalConstants]] git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2" uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" @@ -436,6 +603,18 @@ git-tree-sha1 = "31e996f0a15c7b280ba9f76636b3ff9e2ae58c9a" uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" version = "0.21.4" +[[deps.JSON3]] +deps = ["Dates", "Mmap", "Parsers", "PrecompileTools", "StructTypes", "UUIDs"] +git-tree-sha1 = "eb3edce0ed4fa32f75a0a11217433c31d56bd48b" +uuid = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" +version = "1.14.0" + + [deps.JSON3.extensions] + JSON3ArrowExt = ["ArrowTypes"] + + [deps.JSON3.weakdeps] + ArrowTypes = "31f734f8-188a-4ce0-8406-c8a06bd891cd" + [[deps.JuliaVariables]] deps = ["MLStyle", "NameResolution"] git-tree-sha1 = "49fb3cb53362ddadb4415e9b73926d6b40709e70" @@ -459,19 +638,22 @@ deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Preferences", "Printf", "Requires", git-tree-sha1 = "839c82932db86740ae729779e610f07a1640be9a" uuid = "929cbde3-209d-540e-8aea-75f648917ca0" version = "6.6.3" +weakdeps = ["BFloat16s"] [deps.LLVM.extensions] BFloat16sExt = "BFloat16s" - [deps.LLVM.weakdeps] - BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" - [[deps.LLVMExtra_jll]] deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] git-tree-sha1 = "88b916503aac4fb7f701bb625cd84ca5dd1677bc" uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab" version = "0.0.29+0" +[[deps.LaTeXStrings]] +git-tree-sha1 = "50901ebc375ed41dbf8058da26f9de442febbbec" +uuid = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f" +version = "1.3.1" 
+ [[deps.LazilyInitializedFields]] git-tree-sha1 = "8f7f3cabab0fd1800699663533b6d5cb3fc0e612" uuid = "0e77f7df-68c5-4e49-93ce-4cd80f5598bf" @@ -481,6 +663,11 @@ version = "1.2.2" deps = ["Artifacts", "Pkg"] uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3" +[[deps.LazyModules]] +git-tree-sha1 = "a560dd966b386ac9ae60bdd3a3d3a326062d3c3e" +uuid = "8cdb02fc-e678-4876-92c5-9defec4f444e" +version = "0.3.1" + [[deps.LibCURL]] deps = ["LibCURL_jll", "MozillaCACerts_jll"] uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" @@ -537,6 +724,24 @@ version = "0.3.27" [[deps.Logging]] uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" +[[deps.LoggingExtras]] +deps = ["Dates", "Logging"] +git-tree-sha1 = "c1dd6d7978c12545b4179fb6153b9250c96b0075" +uuid = "e6f89c97-d47a-5376-807f-9c37f3926c36" +version = "1.0.3" + +[[deps.MAT]] +deps = ["BufferedStreams", "CodecZlib", "HDF5", "SparseArrays"] +git-tree-sha1 = "1d2dd9b186742b0f317f2530ddcbf00eebb18e96" +uuid = "23992714-dd62-5051-b70f-ba57cb901cac" +version = "0.10.7" + +[[deps.MLDatasets]] +deps = ["CSV", "Chemfiles", "DataDeps", "DataFrames", "DelimitedFiles", "FileIO", "FixedPointNumbers", "GZip", "Glob", "HDF5", "ImageShow", "JLD2", "JSON3", "LazyModules", "MAT", "MLUtils", "NPZ", "Pickle", "Printf", "Requires", "SparseArrays", "Statistics", "Tables"] +git-tree-sha1 = "aab72207b3c687086a400be710650a57494992bd" +uuid = "eb30cadb-4394-5ae3-aed4-317e484a6458" +version = "0.7.14" + [[deps.MLJFlux]] deps = ["CategoricalArrays", "ColorTypes", "ComputationalResources", "Flux", "MLJModelInterface", "Metalhead", "ProgressMeter", "Random", "Statistics", "Tables"] git-tree-sha1 = "933cc8ec638bd6735c2a05a349f94eb75e59357c" @@ -562,12 +767,23 @@ git-tree-sha1 = "b45738c2e3d0d402dffa32b2c1654759a2ac35a4" uuid = "f1d291b0-491e-4a28-83b9-f70985020b54" version = "0.4.4" +[[deps.MPIPreferences]] +deps = ["Libdl", "Preferences"] +git-tree-sha1 = "c105fe467859e7f6e9a852cb15cb4301126fac07" +uuid = "3da0fdf6-3ccc-4f1b-acd9-58baa6c99267" +version = "0.1.11" + 
[[deps.MacroTools]] deps = ["Markdown", "Random"] git-tree-sha1 = "2fa9ee3e63fd3a4f7a9a4f4744a52f4856de82df" uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" version = "0.5.13" +[[deps.MappedArrays]] +git-tree-sha1 = "2dab0221fe2b0f2cb6754eaa743cc266339f527e" +uuid = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900" +version = "0.4.2" + [[deps.Markdown]] deps = ["Base64"] uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" @@ -578,6 +794,12 @@ git-tree-sha1 = "465a70f0fc7d443a00dcdc3267a497397b8a3899" uuid = "d0879d2d-cac2-40c8-9cee-1863dc0c7391" version = "0.1.2" +[[deps.MbedTLS]] +deps = ["Dates", "MbedTLS_jll", "MozillaCACerts_jll", "NetworkOptions", "Random", "Sockets"] +git-tree-sha1 = "c067a280ddc25f196b5e7df3877c6b226d390aaf" +uuid = "739be429-bea8-5141-9913-cc70e7f3736d" +version = "1.1.9" + [[deps.MbedTLS_jll]] deps = ["Artifacts", "Libdl"] uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" @@ -610,6 +832,12 @@ version = "1.2.0" [[deps.Mmap]] uuid = "a63ad114-7e13-5084-954f-fe012c677804" +[[deps.MosaicViews]] +deps = ["MappedArrays", "OffsetArrays", "PaddedViews", "StackViews"] +git-tree-sha1 = "7b86a5d4d70a9f5cdf2dacb3cbe6d251d1a61dbe" +uuid = "e94cdb99-869f-56ef-bcf0-1ae2bcbe0389" +version = "0.3.4" + [[deps.MozillaCACerts_jll]] uuid = "14a3606d-f60d-562e-9121-12d972cd8159" version = "2023.1.10" @@ -632,6 +860,12 @@ version = "0.9.16" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" +[[deps.NPZ]] +deps = ["FileIO", "ZipFile"] +git-tree-sha1 = "60a8e272fe0c5079363b28b0953831e2dd7b7e6f" +uuid = "15e1cf62-19b3-5cfa-8e77-841668bca605" +version = "0.4.3" + [[deps.NaNMath]] deps = ["OpenLibm_jll"] git-tree-sha1 = "0877504529a3e5c3343c6f8b4c0381e57e4387e4" @@ -648,6 +882,15 @@ version = "0.1.5" uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" version = "1.2.0" +[[deps.OffsetArrays]] +git-tree-sha1 = "e64b4f5ea6b7389f6f046d13d4896a8f9c1ba71e" +uuid = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" +version = "1.14.0" +weakdeps = ["Adapt"] + + 
[deps.OffsetArrays.extensions] + OffsetArraysAdaptExt = "Adapt" + [[deps.OneHotArrays]] deps = ["Adapt", "ChainRulesCore", "Compat", "GPUArraysCore", "LinearAlgebra", "NNlib"] git-tree-sha1 = "963a3f28a2e65bb87a68033ea4a616002406037d" @@ -670,6 +913,12 @@ git-tree-sha1 = "1b2f042897343a9dfdcc9366e4ecbd3d00780c49" uuid = "9bd350c2-7e96-507f-8002-3f2e150b4e1b" version = "8.9.0+1" +[[deps.OpenSSL]] +deps = ["BitFlags", "Dates", "MozillaCACerts_jll", "OpenSSL_jll", "Sockets"] +git-tree-sha1 = "38cb508d080d21dc1128f7fb04f20387ed4c0af4" +uuid = "4d8831e6-92b7-49fb-bdf8-b643e874388c" +version = "1.4.3" + [[deps.OpenSSL_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] git-tree-sha1 = "a12e56c72edee3ce6b96667745e6cbbe5498f200" @@ -698,6 +947,18 @@ deps = ["Artifacts", "Libdl"] uuid = "efcefdf7-47ab-520b-bdef-62a2eaa19f15" version = "10.42.0+1" +[[deps.PackageExtensionCompat]] +git-tree-sha1 = "fb28e33b8a95c4cee25ce296c817d89cc2e53518" +uuid = "65ce6f38-6b18-4e1d-a461-8949797d7930" +version = "1.0.2" +weakdeps = ["Requires", "TOML"] + +[[deps.PaddedViews]] +deps = ["OffsetArrays"] +git-tree-sha1 = "0fac6313486baae819364c52b4f483450a9d793f" +uuid = "5432bcbf-9aad-5242-b902-cca2824c8663" +version = "0.5.12" + [[deps.Parsers]] deps = ["Dates", "PrecompileTools", "UUIDs"] git-tree-sha1 = "8489905bcdbcfac64d1daa51ca07c0d8f0283821" @@ -710,11 +971,29 @@ git-tree-sha1 = "47b49a4dbc23b76682205c646252c0f9e1eb75af" uuid = "570af359-4316-4cb7-8c74-252c00c2016b" version = "1.2.0" +[[deps.PeriodicTable]] +deps = ["Base64", "Unitful"] +git-tree-sha1 = "238aa6298007565529f911b734e18addd56985e1" +uuid = "7b2266bf-644c-5ea3-82d8-af4bbd25a884" +version = "1.2.1" + +[[deps.Pickle]] +deps = ["BFloat16s", "DataStructures", "InternedStrings", "Mmap", "Serialization", "SparseArrays", "StridedViews", "StringEncodings", "ZipFile"] +git-tree-sha1 = "e99da19b86b7e1547b423fc1721b260cfbe83acb" +uuid = "fbb45041-c46e-462f-888f-7c521cafbc2c" +version = "0.3.5" + [[deps.Pkg]] deps = ["Artifacts", 
"Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" version = "1.10.0" +[[deps.PooledArrays]] +deps = ["DataAPI", "Future"] +git-tree-sha1 = "36d8b4b899628fb92c2749eb488d884a926614d3" +uuid = "2dfb63ee-cc39-5dd5-95bd-886bf059d720" +version = "1.4.3" + [[deps.PrecompileTools]] deps = ["Preferences"] git-tree-sha1 = "5aa36f7049a63a1528fe8f7c3f2113413ffd4e1f" @@ -732,6 +1011,12 @@ git-tree-sha1 = "632eb4abab3449ab30c5e1afaa874f0b98b586e4" uuid = "8162dcfd-2161-5ef2-ae6c-7681170c5f98" version = "0.2.0" +[[deps.PrettyTables]] +deps = ["Crayons", "LaTeXStrings", "Markdown", "PrecompileTools", "Printf", "Reexport", "StringManipulation", "Tables"] +git-tree-sha1 = "88b895d13d53b5577fd53379d913b9ab9ac82660" +uuid = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" +version = "2.3.1" + [[deps.Printf]] deps = ["Unicode"] uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" @@ -794,6 +1079,18 @@ git-tree-sha1 = "a8e18eb383b5ecf1b5e6fc237eb39255044fd92b" uuid = "30f210dd-8aff-4c5f-94ba-8e64358c1161" version = "3.0.0" +[[deps.Scratch]] +deps = ["Dates"] +git-tree-sha1 = "3bac05bc7e74a75fd9cba4295cde4045d9fe2386" +uuid = "6c6a2e73-6563-6170-7368-637461726353" +version = "1.2.1" + +[[deps.SentinelArrays]] +deps = ["Dates", "Random"] +git-tree-sha1 = "90b4f68892337554d31cdcdbe19e48989f26c7e6" +uuid = "91c51154-3ec4-41a3-a24f-3f23e20d615c" +version = "1.4.3" + [[deps.Serialization]] uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" @@ -808,6 +1105,11 @@ git-tree-sha1 = "7f534ad62ab2bd48591bdeac81994ea8c445e4a5" uuid = "605ecd9f-84a6-4c9e-81e2-4798472b76a3" version = "0.1.0" +[[deps.SimpleBufferStream]] +git-tree-sha1 = "874e8867b33a00e784c8a7e4b60afe9e037b74e1" +uuid = "777ac1f9-54b0-4bf8-805c-2214025038e7" +version = "1.1.0" + [[deps.SimpleTraits]] deps = ["InteractiveUtils", "MacroTools"] git-tree-sha1 = 
"5d7e3f4e11935503d3ecaf7186eac40602e7d231" @@ -850,6 +1152,12 @@ git-tree-sha1 = "e08a62abc517eb79667d0a29dc08a3b589516bb5" uuid = "171d559e-b47b-412a-8079-5efa626c420e" version = "0.1.15" +[[deps.StackViews]] +deps = ["OffsetArrays"] +git-tree-sha1 = "46e589465204cd0c08b4bd97385e4fa79a0c770c" +uuid = "cae243ae-269e-4f55-b966-ac2d0dc13c15" +version = "0.1.1" + [[deps.StaticArrays]] deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"] git-tree-sha1 = "bf074c045d3d5ffd956fa0a461da38a44685d6b2" @@ -889,6 +1197,30 @@ git-tree-sha1 = "5cf7606d6cef84b543b483848d4ae08ad9832b21" uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" version = "0.34.3" +[[deps.StridedViews]] +deps = ["LinearAlgebra", "PackageExtensionCompat"] +git-tree-sha1 = "5b765c4e401693ab08981989f74a36a010aa1d8e" +uuid = "4db3bf67-4bd7-4b4e-b153-31dc3fb37143" +version = "0.2.2" + + [deps.StridedViews.extensions] + StridedViewsCUDAExt = "CUDA" + + [deps.StridedViews.weakdeps] + CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + +[[deps.StringEncodings]] +deps = ["Libiconv_jll"] +git-tree-sha1 = "b765e46ba27ecf6b44faf70df40c57aa3a547dcb" +uuid = "69024149-9ee7-55f6-a4c4-859efe599b68" +version = "0.3.7" + +[[deps.StringManipulation]] +deps = ["PrecompileTools"] +git-tree-sha1 = "a04cabe79c5f01f4d723cc6704070ada0b9d46d5" +uuid = "892a3eda-7b42-436c-8928-eab12a02cf0e" +version = "0.3.4" + [[deps.StructArrays]] deps = ["ConstructionBase", "DataAPI", "Tables"] git-tree-sha1 = "f4dc295e983502292c4c3f951dbb4e985e35b3be" @@ -902,6 +1234,12 @@ weakdeps = ["Adapt", "GPUArraysCore", "SparseArrays", "StaticArrays"] StructArraysSparseArraysExt = "SparseArrays" StructArraysStaticArraysExt = "StaticArrays" +[[deps.StructTypes]] +deps = ["Dates", "UUIDs"] +git-tree-sha1 = "ca4bccb03acf9faaf4137a9abc1881ed1841aa70" +uuid = "856f2bd8-1eba-4b0a-8007-ebc267875bd4" +version = "1.10.0" + [[deps.SuiteSparse]] deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] uuid = 
"4607b0f0-06f3-5cda-b6b1-a6196a1729e9" @@ -933,6 +1271,12 @@ deps = ["ArgTools", "SHA"] uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" version = "1.10.0" +[[deps.TensorCore]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "1feb45f88d133a655e001435632f019a9a1bcdb6" +uuid = "62fd8b95-f654-4bbd-a8a5-9c27f68ccd50" +version = "0.1.1" + [[deps.Test]] deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" @@ -966,6 +1310,11 @@ version = "0.4.80" OnlineStatsBase = "925886fa-5bf2-5e8e-b522-a9147a512338" Referenceables = "42d2dcc6-99eb-4e98-b66c-637b7d73030e" +[[deps.URIs]] +git-tree-sha1 = "67db6cc7b3821e19ebe75791a9dd19c9b1188f2b" +uuid = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4" +version = "1.5.1" + [[deps.UUIDs]] deps = ["Random", "SHA"] uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" @@ -973,6 +1322,26 @@ uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" [[deps.Unicode]] uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" +[[deps.Unitful]] +deps = ["Dates", "LinearAlgebra", "Random"] +git-tree-sha1 = "dd260903fdabea27d9b6021689b3cd5401a57748" +uuid = "1986cc42-f94f-5a68-af5c-568840ba703d" +version = "1.20.0" + + [deps.Unitful.extensions] + ConstructionBaseUnitfulExt = "ConstructionBase" + InverseFunctionsUnitfulExt = "InverseFunctions" + + [deps.Unitful.weakdeps] + ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" + InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" + +[[deps.UnitfulAtomic]] +deps = ["Unitful"] +git-tree-sha1 = "903be579194534af1c4b4778d1ace676ca042238" +uuid = "a7773ee8-282e-5fa2-be4e-bd808c38a91a" +version = "1.0.0" + [[deps.UnsafeAtomics]] git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278" uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f" @@ -984,6 +1353,23 @@ git-tree-sha1 = "323e3d0acf5e78a56dfae7bd8928c989b4f3083e" uuid = "d80eeb9a-aca5-4d75-85e5-170c8b632249" version = "0.1.3" +[[deps.WeakRefStrings]] +deps = ["DataAPI", "InlineStrings", "Parsers"] +git-tree-sha1 = 
"b1be2855ed9ed8eac54e5caff2afcdb442d52c23" +uuid = "ea10d353-3f73-51f8-a26c-33c1cb351aa5" +version = "1.4.2" + +[[deps.WorkerUtilities]] +git-tree-sha1 = "cd1659ba0d57b71a464a29e64dbc67cfe83d54e7" +uuid = "76eceee3-57b5-4d4a-8e66-0e911cebbf60" +version = "1.6.1" + +[[deps.ZipFile]] +deps = ["Libdl", "Printf", "Zlib_jll"] +git-tree-sha1 = "f492b7fe1698e623024e873244f10d89c95c340a" +uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea" +version = "0.10.1" + [[deps.Zlib_jll]] deps = ["Libdl"] uuid = "83775a58-1f1d-513f-b197-d71354ab007a" diff --git a/docs/Project.toml b/docs/Project.toml index af94c237..e4b51035 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -2,4 +2,5 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" DocumenterTools = "35a29f4d-8980-5a13-9543-d66fff28ecb8" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" +MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845" diff --git a/docs/make.jl b/docs/make.jl index b1fd3ca6..423e60d6 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -7,6 +7,7 @@ DocMeta.setdocmeta!(MLJFlux, :DocTestSetup, :(using MLJFlux); recursive=true) makedocs( sitename = "MLJFlux", format = Documenter.HTML(; + collapselevel = 1, assets = [ "assets/favicon.ico", asset( @@ -23,10 +24,24 @@ makedocs( modules = [MLJFlux], warnonly = true, pages = ["Introduction" => "index.md", - "API"=> "api.md", - "Features" => Any[ - "Tuning"=>"features/tuning.md", - "Early Stopping"=>"features/early.md", + "Interface"=> Any[ + "Summary"=>"interface/Summary.md", + "Builders"=>"interface/Builders.md", + "Custom Builders"=>"interface/Custom Builders.md", + "Classification"=>"interface/Classification.md", + "Regression"=>"interface/Regression.md", + "Multi-Target Regression"=>"interface/Multitarget Regression.md", + "Image Classification"=>"interface/Image Classification.md", + ], + "Workflow Examples" => Any[ + "Incremental Training"=>"workflow examples/Incremental Training.md", + "Validation and Hyperparameter 
Tuning"=>"workflow examples/Hyperparameter Tuning.md", + "Early Stopping"=>"workflow examples/Early Stopping.md", + "Model Composition"=>"workflow examples/Composition.md", + ], + "Tutorials"=>Any[ + "MNIST Digits Classification"=>"full tutorials/MNIST.md", + "Boston House Prices Prediction"=>"full tutorials/Boston.md", ], "Contributing" => "contributing.md", "About" => "about.md"], diff --git a/docs/src/api.md b/docs/src/api.md deleted file mode 100644 index 5c2e5374..00000000 --- a/docs/src/api.md +++ /dev/null @@ -1,36 +0,0 @@ - - -```@docs -MLJFlux.ImageClassifier -``` - -```@docs -MLJFlux.NeuralNetworkClassifier -``` - -```@docs -MLJFlux.NeuralNetworkRegressor -``` - -```@docs -MLJFlux.MultitargetNeuralNetworkRegressor -``` - -```@docs -MLJFlux.Linear -``` - -```@docs -MLJFlux.Short -``` - -```@docs -MLJFlux.MLP -``` - -```@docs -MLJFlux.@builder -``` - - - diff --git a/docs/src/assets/themes/documenter-light.css b/docs/src/assets/themes/documenter-light.css index a6489b46..990fcc5b 100644 --- a/docs/src/assets/themes/documenter-light.css +++ b/docs/src/assets/themes/documenter-light.css @@ -11716,4 +11716,18 @@ code.hljs { .input.is-rounded, #documenter .docs-sidebar form.docs-search>input { margin: 1.5rem 0.0rem !important; +} + +th, td { + text-align: left !important; +} + +summary { + cursor: pointer; + margin: 1rem 0rem; + +} + +details { + margin-bottom: 1.5rem; } \ No newline at end of file diff --git a/docs/src/contributing.md b/docs/src/contributing.md index e69de29b..f4989cff 100644 --- a/docs/src/contributing.md +++ b/docs/src/contributing.md @@ -0,0 +1,15 @@ +### Adding new models to MLJFlux (advanced) + +This section is mainly for MLJFlux developers. 
It assumes familiarity +with the [MLJ model +API](https://alan-turing-institute.github.io/MLJ.jl/dev/adding_models_for_general_use/). + +If one subtypes a new model type as either +`MLJFlux.MLJFluxProbabilistic` or `MLJFlux.MLJFluxDeterministic`, then +instead of defining new methods for `MLJModelInterface.fit` and +`MLJModelInterface.update` one can make use of fallbacks by +implementing the lower level methods `shape`, `build`, and +`fitresult`. See the [classifier source code](/src/classifier.jl) for +an example. + +One still needs to implement a new `predict` method. \ No newline at end of file diff --git a/docs/src/features/early.md b/docs/src/full tutorials/Boston.md similarity index 100% rename from docs/src/features/early.md rename to docs/src/full tutorials/Boston.md diff --git a/docs/src/full tutorials/MNIST.md b/docs/src/full tutorials/MNIST.md new file mode 100644 index 00000000..8c206201 --- /dev/null +++ b/docs/src/full tutorials/MNIST.md @@ -0,0 +1,96 @@ +## Image Classification Example +An expanded version of this example, with early stopping and +snapshots, is available [here](/examples/mnist). + +We define a builder that builds a chain with six alternating +convolution and max-pool layers, and a final dense layer, which we +apply to the MNIST image dataset. + +First we define a generic builder (working for any image size, color +or gray): + +```julia +using MLJ +using Flux +using MLDatasets + +# helper function +function flatten(x::AbstractArray) + return reshape(x, :, size(x)[end]) +end + +import MLJFlux +mutable struct MyConvBuilder + filter_size::Int + channels1::Int + channels2::Int + channels3::Int +end + +function MLJFlux.build(b::MyConvBuilder, rng, n_in, n_out, n_channels) + + k, c1, c2, c3 = b.filter_size, b.channels1, b.channels2, b.channels3 + + mod(k, 2) == 1 || error("`filter_size` must be odd. 
") + + # padding to preserve image size on convolution: + p = div(k - 1, 2) + + front = Chain( + Conv((k, k), n_channels => c1, pad=(p, p), relu), + MaxPool((2, 2)), + Conv((k, k), c1 => c2, pad=(p, p), relu), + MaxPool((2, 2)), + Conv((k, k), c2 => c3, pad=(p, p), relu), + MaxPool((2 ,2)), + flatten) + d = Flux.outputsize(front, (n_in..., n_channels, 1)) |> first + return Chain(front, Dense(d, n_out)) +end +``` +Next, we load some of the MNIST data and check scientific types +conform to those in the table above: + +```julia +N = 500 +Xraw, yraw = MNIST.traindata(); +Xraw = Xraw[:,:,1:N]; +yraw = yraw[1:N]; + +scitype(Xraw) +``` +```julia +scitype(yraw) +``` + +Inputs should have element scitype `GrayImage`: + +```julia +X = coerce(Xraw, GrayImage); +``` + +For classifiers, target must have element scitype `<: Finite`: + +```julia +y = coerce(yraw, Multiclass); +``` + +Instantiating an image classifier model: + +```julia +ImageClassifier = @load ImageClassifier +clf = ImageClassifier(builder=MyConvBuilder(3, 16, 32, 32), + epochs=10, + loss=Flux.crossentropy) +``` + +And evaluating the accuracy of the model on a 30% holdout set: + +```julia +mach = machine(clf, X, y) + +evaluate!(mach, + resampling=Holdout(rng=123, fraction_train=0.7), + operation=predict_mode, + measure=misclassification_rate) +``` diff --git a/docs/src/index.md b/docs/src/index.md index 34d6171a..fff0a884 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -2,7 +2,7 @@ A Julia package integrating deep learning Flux models with MLJ. -### Objectives +## Objectives - Provide a user-friendly and high-level interface to fundamental [Flux](https://fluxml.ai/Flux.jl/stable/) deep learning models while still being extensible by supporting custom models written with Flux @@ -15,7 +15,7 @@ A Julia package integrating deep learning Flux models with MLJ. 
Also note that MLJFlux is limited to training models only when all training data fits into memory, though it still supports automatic batching of data. -### Installation +## Installation ```julia import Pkg @@ -24,7 +24,7 @@ Pkg.add(["MLJ", "MLJFlux", "Flux"]) ``` You only need `Flux` if you need to build a custom architecture or experiment with different optimizers, loss functions and activations. -### Quick Start +## Quick Start First load and instantiate mode: ```@example using MLJ, Flux, MLJFlux @@ -41,7 +41,7 @@ clf = NeuralNetworkClassifier( optimiser=Flux.ADAM(0.01), batch_size=8, epochs=100, - acceleration=CUDALibs() + acceleration=CUDALibs() # For GPU support ) # 3. Wrap it in a machine in fit @@ -54,9 +54,23 @@ evaluate!(mach, resampling=cv, measure=accuracy) ``` As you can see we were able to use MLJ functionality (i.e., cross validation) with a Flux deep learning model. All arguments provided also have defaults. -Notice that we were also able to define the neural network in a high-level fashion by only specifying the number of neurons per each hidden layer and the activation function. Meanwhile, `MLJFlux` was able to infer the input and output layer as well as use a suitable default for the loss function and output activation given the classification task. +Notice that we were also able to define the neural network in a high-level fashion by only specifying the number of neurons per each hidden layer and the activation function. Meanwhile, `MLJFlux` was able to infer the input and output layer as well as use a suitable default for the loss function and output activation given the classification task. Notice as well that we did not need to implement a training or prediction loop as in `Flux`. -### Flux or MLJFlux? 
+## Basic idea + +As in the example above, any MLJFlux model has a `builder` hyperparameter, an object encoding +instructions for creating a neural network given the data that the +model eventually sees (e.g., the number of classes in a classification +problem). While each MLJFlux model has a simple default builder, users +may need to define custom builders to get optimal results, +and this will require familiarity with the [Flux +API](https://fluxml.ai/Flux.jl/stable/) for defining a neural network +chain. + +In the future MLJFlux may provide a larger assortment of canned +builders. Pull requests introducing new ones are most welcome. + +## Flux or MLJFlux? [Flux](https://fluxml.ai/Flux.jl/stable/) is a deep learning framework in Julia that comes with everything you need to build deep learning models (i.e., GPU support, automatic differentiation, layers, activations, losses, optimizers, etc.). [MLJFlux](https://github.com/FluxML/MLJFlux.jl) wraps models built with Flux which provides a more high-level interface for building and training such models. 
More importantly, it empowers Flux models by extending their support to many common machine learning workflows that are possible via MLJ such as: - **Estimating performance** of your model using a holdout set or other resampling strategy (e.g., cross-validation) as measured by one or more metrics (e.g., loss functions) that may not have been used in training diff --git a/docs/src/interface/Builders.md b/docs/src/interface/Builders.md new file mode 100644 index 00000000..ea7dd24c --- /dev/null +++ b/docs/src/interface/Builders.md @@ -0,0 +1,16 @@ + +```@docs +MLJFlux.Linear +``` + +```@docs +MLJFlux.Short +``` + +```@docs +MLJFlux.MLP +``` + +```@docs +MLJFlux.@builder +``` diff --git a/docs/src/interface/Classification.md b/docs/src/interface/Classification.md new file mode 100644 index 00000000..0491e8fc --- /dev/null +++ b/docs/src/interface/Classification.md @@ -0,0 +1,3 @@ +```@docs +MLJFlux.NeuralNetworkClassifier +``` \ No newline at end of file diff --git a/docs/src/interface/Custom Builders.md b/docs/src/interface/Custom Builders.md new file mode 100644 index 00000000..78b1ada2 --- /dev/null +++ b/docs/src/interface/Custom Builders.md @@ -0,0 +1,61 @@ +### Defining Custom Builders + +Following is an example defining a new builder for creating a simple +fully-connected neural network with two hidden layers, with `n1` nodes +in the first hidden layer, and `n2` nodes in the second, for use in +any of the first three models in Table 1. The definition includes one +mutable struct and one method: + +```julia +mutable struct MyBuilder <: MLJFlux.Builder + n1 :: Int + n2 :: Int +end + +function MLJFlux.build(nn::MyBuilder, rng, n_in, n_out) + init = Flux.glorot_uniform(rng) + return Chain(Dense(n_in, nn.n1, init=init), + Dense(nn.n1, nn.n2, init=init), + Dense(nn.n2, n_out, init=init)) +end +``` + +Note here that `n_in` and `n_out` depend on the size of the data (see +Table 1). + +For a concrete image classification example, see +[examples/mnist](examples/mnist). 
+ +More generally, defining a new builder means defining a new struct +sub-typing `MLJFlux.Builder` and defining a new `MLJFlux.build` method +with one of these signatures: + +```julia +MLJFlux.build(builder::MyBuilder, rng, n_in, n_out) +MLJFlux.build(builder::MyBuilder, rng, n_in, n_out, n_channels) # for use with `ImageClassifier` +``` + +This method must return a `Flux.Chain` instance, `chain`, subject to the +following conditions: + +- `chain(x)` must make sense: + + - for any `x <: Array{<:AbstractFloat, 2}` of size `(n_in, + batch_size)` where `batch_size` is any integer (for use with one + of the first three model types); or + + - for any `x <: Array{<:Float32, 4}` of size `(W, H, n_channels, + batch_size)`, where `(W, H) = n_in`, `n_channels` is 1 or 3, and + `batch_size` is any integer (for use with `ImageClassifier`) + +- The object returned by `chain(x)` must be an `AbstractFloat` vector + of length `n_out`. + +Alternatively, use `MLJFlux.@builder(neural_net)` to automatically create a builder for +any valid Flux chain expression `neural_net`, where the symbols `n_in`, `n_out`, +`n_channels` and `rng` can appear literally, with the interpretations explained above. 
For +example, + +``` +builder = MLJFlux.@builder Chain(Dense(n_in, 128), Dense(128, n_out, tanh)) +``` \ No newline at end of file diff --git a/docs/src/interface/Image Classification.md b/docs/src/interface/Image Classification.md new file mode 100644 index 00000000..1af989a8 --- /dev/null +++ b/docs/src/interface/Image Classification.md @@ -0,0 +1,3 @@ +```@docs +MLJFlux.ImageClassifier +``` diff --git a/docs/src/interface/Multitarget Regression.md b/docs/src/interface/Multitarget Regression.md new file mode 100644 index 00000000..2257e9d4 --- /dev/null +++ b/docs/src/interface/Multitarget Regression.md @@ -0,0 +1,3 @@ +```@docs +MLJFlux.MultitargetNeuralNetworkRegressor +``` \ No newline at end of file diff --git a/docs/src/interface/Regression.md b/docs/src/interface/Regression.md new file mode 100644 index 00000000..f19f4b8e --- /dev/null +++ b/docs/src/interface/Regression.md @@ -0,0 +1,3 @@ +```@docs +MLJFlux.NeuralNetworkRegressor +``` \ No newline at end of file diff --git a/docs/src/interface/Summary.md b/docs/src/interface/Summary.md new file mode 100644 index 00000000..12a853c7 --- /dev/null +++ b/docs/src/interface/Summary.md @@ -0,0 +1,143 @@ +## Models + +MLJFlux provides four model types, for use with input features `X` and +targets `y` of the [scientific +type](https://alan-turing-institute.github.io/MLJScientificTypes.jl/dev/) +indicated in the table below. The parameters `n_in`, `n_out` and `n_channels` +refer to information passed to the builder, as described under +[Defining a new builder](defining-a-new-builder) below. 
+ +Model Type | Prediction type | `scitype(X) <: _` | `scitype(y) <: _` +-----------|-----------------|---------------|---------------------------- +`NeuralNetworkRegressor` | `Deterministic` | `Table(Continuous)` with `n_in` columns | `AbstractVector{<:Continuous}` (`n_out = 1`) +`MultitargetNeuralNetworkRegressor` | `Deterministic` | `Table(Continuous)` with `n_in` columns | `<: Table(Continuous)` with `n_out` columns +`NeuralNetworkClassifier` | `Probabilistic` | `<:Table(Continuous)` with `n_in` columns | `AbstractVector{<:Finite}` with `n_out` classes +`ImageClassifier` | `Probabilistic` | `AbstractVector{<:Image{W,H}}` with `n_in = (W, H)` | `AbstractVector{<:Finite}` with `n_out` classes + + +```@raw html +
See definition of "model" +``` +In MLJ a *model* is a mutable struct storing hyper-parameters for some +learning algorithm indicated by the model name, and that's all. In +particular, an MLJ model does not store learned parameters. + +*Warning:* In Flux the term "model" has another meaning. However, as all +Flux "models" used in MLJFlux are `Flux.Chain` objects, we call them +*chains*, and restrict use of "model" to models in the MLJ sense. +```@raw html +
+``` + +```@raw html +
Dealing with non-tabular input +``` +Any `AbstractMatrix{<:AbstractFloat}` object `Xmat` can be forced to +have scitype `Table(Continuous)` by replacing it with ` X = +MLJ.table(Xmat)`. Furthermore, this wrapping, and subsequent +unwrapping under the hood, will compile to a no-op. At present this +includes support for sparse matrix data, but the implementation has +not been optimized for sparse data at this time and so should be used +with caution. + +Instructions for coercing common image formats into some +`AbstractVector{<:Image}` are +[here](https://juliaai.github.io/ScientificTypes.jl/dev/#Type-coercion-for-image-data). +```@raw html +
+``` + +```@raw html +
Fitting and warm restarts +``` +MLJ machines cache state enabling the "warm restart" of model +training, as demonstrated in the incremental training example. In the case of MLJFlux +models, `fit!(mach)` will use a warm restart if: + +- only `model.epochs` has changed since the last call; or + +- only `model.epochs` or `model.optimiser` have changed since the last + call and `model.optimiser_changes_trigger_retraining == false` (the + default) (the "state" part of the optimiser is ignored in this + comparison). This allows one to dynamically modify learning rates, + for example. + +Here `model=mach.model` is the associated MLJ model. + +The warm restart feature makes it possible to apply early stopping +criteria, as defined in +[EarlyStopping.jl](https://github.com/ablaom/EarlyStopping.jl). For an +example, see [/examples/mnist/](/examples/mnist/). (Eventually, this +will be handled by an MLJ model wrapper for controlling arbitrary +iterative models.) +```@raw html +
+``` + + + +## Model Hyperparameters. + +All models share the following hyper-parameters: + +| Hyper-parameter | Description | Default | +|----------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------------| +| `builder` | Default builder for models. | `MLJFlux.Linear(σ=Flux.relu)` (regressors) or `MLJFlux.Short(n_hidden=0, dropout=0.5, σ=Flux.σ)` (classifiers) | +| `optimiser` | The optimiser to use for training. | `Flux.ADAM()` | +| `loss` | The loss function used for training. | `Flux.mse` (regressors) and `Flux.crossentropy` (classifiers) | +| `epochs` | Number of epochs to train for. | `10` | +| `batch_size` | The batch size for the data. | `1` | +| `lambda` | The regularization strength. Range = [0, ∞). | `0` | +| `alpha` | The L2/L1 mix of regularization. Range = [0, 1]. | `0` | +| `rng` | The random number generator (RNG) passed to builders, for weight initialization, for example. Can be any `AbstractRNG` or the seed (integer) for a `MersenneTwister` that is reset on every cold restart of model (machine) training. | `GLOBAL_RNG` | +| `acceleration` | Use `CUDALibs()` for training on GPU; default is `CPU1()`. | `CPU1()` | +| `optimiser_changes_trigger_retraining` | True if fitting an associated machine should trigger retraining from scratch whenever the optimiser changes. | `false` | + + +The classifiers have an additional hyperparameter `finaliser` (default += `Flux.softmax`) which is the operation applied to the unnormalized +output of the final layer to obtain probabilities (outputs summing to +one). Default = `Flux.softmax`. It should return a vector of the same +length as its input. + +!!! 
note "Loss Functions" + Currently, the loss function specified by `loss=...` is applied + internally by Flux and needs to conform to the Flux API. You cannot, + for example, supply one of MLJ's probabilistic loss functions, such as + `MLJ.cross_entropy` to one of the classifier constructors, although + you *should* use MLJ loss functions in MLJ meta-algorithms. + +```@raw html +
More on accelerated training with GPUs +``` +As in the table, when instantiating a model for training on a GPU, specify +`acceleration=CUDALibs()`, as in + +```julia +using MLJ +ImageClassifier = @load ImageClassifier +model = ImageClassifier(epochs=10, acceleration=CUDALibs()) +mach = machine(model, X, y) |> fit! +``` + +In this example, the data `X, y` is copied onto the GPU under the hood +on the call to `fit!` and cached for use in any warm restart (see +above). The Flux chain used in training is always copied back to the +CPU at the conclusion of `fit!`, and made available as +`fitted_params(mach)`. +```@raw html +
+``` + + +## Built-in builders + +As for the `builder` argument, the following builders are provided out-of-the-box: + +|Builder | Description | +|:-------------------------|:-----------------------------------------------------| +| `MLJFlux.MLP(hidden=(10,))` | General multi-layer perceptron | +| `MLJFlux.Short(n_hidden=0, dropout=0.5, σ=sigmoid)` | Fully connected network with one hidden layer and dropout| +| `MLJFlux.Linear(σ=relu)` | Vanilla linear network with no hidden layers and activation function `σ` | + +See the following sections to learn more about the interface for the builders and models. diff --git a/docs/src/features/tuning.md b/docs/src/workflow examples/Composition.md similarity index 100% rename from docs/src/features/tuning.md rename to docs/src/workflow examples/Composition.md diff --git a/docs/src/workflow examples/Early Stopping.md b/docs/src/workflow examples/Early Stopping.md new file mode 100644 index 00000000..e69de29b diff --git a/docs/src/workflow examples/Hyperparameter Tuning.md b/docs/src/workflow examples/Hyperparameter Tuning.md new file mode 100644 index 00000000..e69de29b diff --git a/docs/src/workflow examples/Incremental Training.md b/docs/src/workflow examples/Incremental Training.md new file mode 100644 index 00000000..9f07f808 --- /dev/null +++ b/docs/src/workflow examples/Incremental Training.md @@ -0,0 +1,55 @@ +#### Incremental training + +```julia +import Random.seed!; seed!(123) +mach = machine(clf, X, y) +fit!(mach) + +julia> training_loss = cross_entropy(predict(mach, X), y) |> mean +0.9064070459118777 + +# Increasing learning rate and adding iterations: +clf.optimiser.eta = clf.optimiser.eta * 2 +clf.epochs = clf.epochs + 5 + +julia> fit!(mach, verbosity=2) +[ Info: Updating Machine{NeuralNetworkClassifier{Short,…},…} @804. 
+[ Info: Loss is 0.8686 +[ Info: Loss is 0.8228 +[ Info: Loss is 0.7706 +[ Info: Loss is 0.7565 +[ Info: Loss is 0.7347 +Machine{NeuralNetworkClassifier{Short,…},…} @804 trained 2 times; caches data + args: + 1: Source @985 ⏎ `Table{AbstractVector{Continuous}}` + 2: Source @367 ⏎ `AbstractVector{Multiclass{3}}` + +julia> training_loss = cross_entropy(predict(mach, X), y) |> mean +0.7347092796453824 +``` + +#### Accessing the Flux chain (model) + +```julia +julia> fitted_params(mach).chain +Chain(Chain(Dense(4, 3, σ), Flux.Dropout{Float64}(0.5, false), Dense(3, 3)), softmax) +``` + +#### Evolution of out-of-sample performance + +```julia +r = range(clf, :epochs, lower=1, upper=200, scale=:log10) +curve = learning_curve(clf, X, y, + range=r, + resampling=Holdout(fraction_train=0.7), + measure=cross_entropy) +using Plots +plot(curve.parameter_values, + curve.measurements, + xlab=curve.parameter_name, + xscale=curve.parameter_scale, + ylab = "Cross Entropy") + +``` + +![](examples/iris/iris_history.png) \ No newline at end of file