From a0914feeca766f0abc9717eab3ac5115a2ef2586 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 24 Apr 2024 14:02:37 -0400 Subject: [PATCH 1/6] Use ReTestItems for parallel testing --- .JuliaFormatter.toml | 3 +- .buildkite/pipeline.yml | 2 + .github/workflows/CI.yml | 2 + ...lPreferences.toml => LocalPreferences.toml | 0 Project.toml | 53 ++++-- README.md | 3 +- docs/make.jl | 11 +- docs/pages.jl | 12 +- docs/src/index.md | 3 +- docs/src/tutorials/basic_mnist_deq.md | 15 +- docs/src/tutorials/reduced_dim_deq.md | 15 +- ...EquilibriumNetworksSciMLSensitivityExt.jl} | 12 +- ext/DeepEquilibriumNetworksZygoteExt.jl | 17 +- src/DeepEquilibriumNetworks.jl | 26 ++- src/layers.jl | 112 +++++------ src/utils.jl | 18 +- test/Project.toml | 26 --- test/layers.jl | 180 ------------------ test/layers_tests.jl | 180 ++++++++++++++++++ test/qa.jl | 7 - test/qa_tests.jl | 17 ++ test/runtests.jl | 8 +- test/{test_utils.jl => shared_testsetup.jl} | 24 ++- test/utils.jl | 38 ---- test/utils_tests.jl | 38 ++++ 25 files changed, 427 insertions(+), 395 deletions(-) rename test/LocalPreferences.toml => LocalPreferences.toml (100%) rename ext/{DeepEquilibriumNetworksLinearSolveSciMLSensitivityExt.jl => DeepEquilibriumNetworksSciMLSensitivityExt.jl} (61%) delete mode 100644 test/Project.toml delete mode 100644 test/layers.jl create mode 100644 test/layers_tests.jl delete mode 100644 test/qa.jl create mode 100644 test/qa_tests.jl rename test/{test_utils.jl => shared_testsetup.jl} (73%) delete mode 100644 test/utils.jl create mode 100644 test/utils_tests.jl diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml index f632a90a..35969990 100644 --- a/.JuliaFormatter.toml +++ b/.JuliaFormatter.toml @@ -3,4 +3,5 @@ whitespace_in_kwargs = false format_docstrings = true separate_kwargs_with_semicolon = true format_markdown = true -annotate_untyped_fields_with_any = false \ No newline at end of file +annotate_untyped_fields_with_any = false +join_lines_based_on_source = false diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index f1f11a31..eb6978be 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -54,5 +54,7 @@ steps: timeout_in_minutes: 240 env: + RETESTITEMS_NWORKERS: 4 + RETESTITEMS_NWORKER_THREADS: 2 SECRET_CODECOV_TOKEN: "fbSN+ZbScLIWr1FOpAu1Z8PYWFobqbLGFayOgZE1ebhE8LIH/PILGXUMcdm9gkXVSwgdETDD0s33k14lBkJ90O4dV9w6k79F/pEgzVHV8baMoXZG03BPMxztlcoRXrKtRtAp+MwoATc3Ldb9H5vqgAnVNn5rhn4Rp0Z6LOVRC43hbhKBBKYh/N4gqpIQlcW4dBXmELhlnMFnUILjwGRVgEt/zh8H+vmf0qiIulNIQ/rfGISROHqFML0QDL4icloiqX08J76ZP/gZCeg6rJ0gl3ok3IspNPz51rlbvijqsPNyIHWi29OrAtWX3qKHfrAOoGIrE1d5Oy4wx4XaN/YBhg==;U2FsdGVkX188gcRjkUNMEC2Z5fEFfhsYY4WJbhhINOuCUgqq9XNHVDbJhzFUFVQ+UiuPHFg7CW/gn+3IkSVyOA==" SECRET_DOCUMENTER_KEY: "jzyAET5IdazYwPAEZAmYmnBALb2dC1GPizCDCdt8xpjIi4ce6QbGGJMKo00ZNzJ/A7ii4bhqysVPXniifFwIGl7x+GSCeavwcSr15pfxJSqPuQYLKxESzIo+SM+l2uJWUz8KYMJ1tSt/Z3Up3qQfLeQFtR+f43b9QrLfhgZGAAdxpwu5VHdI3Xm/gZo5d8xEJ1xs4gqVP0e2A5EFr/j/exaWJL9+AvgO+Gko8NaJGG5B89zP1W2NBlpjttbwzj2naBhDx8A43Qe4eXm+BZd9CIZImiEJnnqoGxLkAyLDksbA68getUHW5z3nGyhWTrg5yfRqq0uyZZGTIOFz6dJrRg==;U2FsdGVkX19QOxLLkdNoQf7Rid3mcSR/renIHQ+/X3o0WxTmU8KDDxzfKuWPeK1fxMon8y45HCJv3HlMuzyfvPWrOmUXccfHK272D8vHu1kk/qZZw8nPd7iYBU9+VAIxwfmI3Av2gC+8tUlOcuUTEVMtMbi/MiLHp+phLYcELKzzrxL8VdrLzna81M+8xVLu7zzNuyK0cUPWLxRHcZc/fewK5Nh7EQ2x8u1b6e5zR0/AcqjCzMayD1RiE7QhRVGdF5GJYnAxc1eoyCwIjXTRfFo0a0Q2h6DEz9FEat/ZCekIuWyVrUkGbpsRqXUTrSH0An7FRRqRlZ9lStRaQY4Z3XBkoIh94vQlXwwLUH20jC7yRTV73CeYmhfigQckHL0JsjjIENz04Ac346fCV6WNQtEak0m3pN/BucoiwRA8l+WU4AK1r84cwGSphKk4SnWRAqeZVuFHck7NkcmHDEkO4C7WTP400oui/5NDMtVZbtnZfLxVzQqijxXj7IflWqF1vKqGmW5aPFMVNeAqwNGu3xM4oIIeHRu0u+k2S5dp1wqRVlMxYXdPtcoFzE0CNsMQdWgsvPd2eet38YRc8ftXNjKzoUSRRCbjGbVr0iJXeNmPg3jfZoVdILHjCN/hcz4nY+61P11OlJAdfE/6HzEr4VoOS4CN+s/brjWycmAKZo2+1e4fSV1xBH7t1spOlESLvsBhZNtj9/zUKgWgMct5hnF4anQcPAeRpz/MBrkwX1gW3WOvCxaqVlRfgGSy6boPgRd3p/ZXN4Xnfeg9RFqKZn21d2gcrc3/1+PTUEkOIv+C9BGszo9IaUziW/Tz2mVP386kX86SF4fF4y3PofcUT2FLTm8Q9ZJBnslOsRP8bq3rIjDiQR3Iz3uGctkGZPs+GOtCR5OrhnnS6BXxkGwt/n9PJsnbXt0Z4tuXihC1B8KfP7mzDvZr3q9X/DGKyZ+oMHdDI+f2+lRwx42nJnsu+nZW9lyhdIwWla9F1rIoVz59HbUrmUhsVmFQYfjy7Nl18g8Wh5r9CkFL/vr6Zpy5lj1J/vhe1501X2FIkKOnLAM73GwtAa4GkbHyu5rNcij6YoozPrJWT4KRNFWGVAqNZ1atG8WwmziwIl2KfBn8jiuP/8o6rXQkmrAzBr6jVnto5FTWnIexEmnbELs20XDck8pO5WQxU1IR9YhKMbrDGbn0jWzVoRmCWpaJgV1AkWu09a++DxIec4+Zt+3SZLj/H57XsBchWHmkFz4NVTBeSans26VmdDd3LxprT8qeH6cioceakmu6yegsKQnJGLmSNyUkHqBqmsCcvyTUyaQUBTFkjLmDeZB3Ifu2kD7AFdx5n58wdJTMZxYviybOCgCV4qe95v5XfIqthp5mF/0F1Wt9ZcEreFSM2Paj5GrQ+M25cZ+kqOSlMet51Q+QBCfQyDF8jdu3j1hVniwpgMI1gqyb2alRfyNx52elTqRn9hPqpFptGH4uJXi8H72YPe4fYkFS7wwELeRIv+nKkNYNLPQAyQFvZ/qB/PRI1YoFBbpi0Vi6iE9xLRq7QVvhJde2EgNbvQk8uakwV630Tht2OuwVdJu/PIbXsQ5i+EuknIlPRdQdhbEIkpuBHFAzxBqA2K92gJ4bbcOjGtDHc0pt0RtvIVoyyJMkYVzr1yBeMWEmsL8qYJ5yzuAFGqpCTmJzXE0ETZLDDJtxwSKj5M2vG59wPNfo9DF+LgJLF+94VydYGNOHy9KuY2Oo3ejV7iFXUtsEV3Id9EkNGq8+t5KIAGk6lnDcM1TTOAc5W8fGGNhYzlqgWK1n3nwPJLykqY7VFHPZjF8Il/8E1IubnPCIyOCTJwKqQlBB5td/bt7YIDEFmkpl7OvUwyc2uYkFmrxGv81OtopsYZOJ+WnwSkqqZ3p2MyqNj3xp92p8itz5tM3tzjrkdfPXsx1QJGY+rkZhCsSf6DSG18AFqI4+Q8uWUwqO5/TJb2z/F2LT88+wJfGPtwGeR+98XgvwjsMWIA/TZfwTrTQsZX2YOIf0bg3yjlEbFM16xAFAA2oItBuvbC6d6NIit4Dukn2WamnOceoTyO6mdHYRh5SBOryr3AWnBJZsUPL3HsC+Xiibgixuwjjalj+HOrAzDlQc8L0Z77dZJhpST0x/gwCleSA3lOKs7MA8ASolCaPVL2pPJXkb97mBxZx8k1n6abhK1w3QVJuYvp7CyGhavsYEqcR+vYx/T0tN4MVOjfRhimqhNihz0VDfY97YS5XavZV07jycqoAlufmH5VSwNbiy8/NY6Q2djc46ISbqvKr6Pf0TZBuJti2gIpm02Btm4rMCawVPpEPieU3GI77nhQp6orq0Zjl5f4XfOKnfcxiqdgip4SVFTglHBTWTiRFnGTF0I3VX4V+RRmqJwwKPN8cxDsNd6wSpylhDUAMfxEvvb+0vAt1yGNUC52OB4bSOXOyZIAU8+08xl7mYGIVUnoWHaR6Y0aHdnywJUuzQ2q3dotfnI1j72MzlHsTK6Lro3YiolDNJpTqLtxmSzkWctw/PfijnoEXtmDnZKptZ7t0v7oTAkdE3kk0RrnFTnMAkyCOREcFcyxglROCoDHsZx3Q+MkWLG/tPMVpuMRhy9gJ1WZTpeExNgs5KgwtrS1HJg7KunWXguFH/zDgODTdKclgfvsVe/SCtlpbO6z5fZji0j1y9LRBVLyTN/LzeR40OBX3r0abk4SGyslAdZMgg2WJdSLVAJ+MtxxbnlKDXDNmu5YehpWdTvm/wIYwTKw+1A48plKburw8fBEofVy9Ubmc8E4z6hQRX2cwcNN9N/60aCwlpM7wVbYfBo4Hw9H/6EawjbRRN9UwmgsfmYUuqCTSi8fNi2dR36bqaoHHURgyqW7DiR7BYgVnOZ+B/2GM8uO2rYgSOhVJf+OK+2HsNly0MW5v3/ft6W7PEsab8IweYWmPLVvJNfHW5CDP6KotdDgm/DcD5owgQ12D95BGWawR5gQxpyjX9uIlxORPq6h0Z79j8gFFsYIfddIdxsJZS9r59FtZe2JL7nK3Dum07tXDGlBCUD4mwv+LNxOJLa8DM6YoEd4Nh8qosfQNJu505Vh/r+PgegnFvG9LRkwQnk8fgPTNKThB067s82YuVg0mv4O9q4Hlm13wTWuvlMr4k1ShBrNyy08YaFCu2hmZm7RizU1rU5MpieiwpQ6cGx+sHBTszB+c89045n2TM4VUedi+vjEq2KuhmIl9ID0EHvWDy5iwOQV5nJ7Rk/Alky2GRZ1CpnJPN29q5lrs6fhvfPquolJTyBTNgVjQ7f0z1zuUQFdhWQX5BFyq/bT46qc+X6dSOvlFenioPDe5MYpA3SZCi2lmVQqHnTOcRZp2HtYpoRUzNB6cT7P1wkRTvAR5PQyuAknkKB+T6HvRb2H8EBLfk+imlyu7mb2iAJNORpZ0Rz+no/5A3wd6qHfTq27h/CDZ91YqGamylrLBdFqyefFYfSbFa1BKikiavpZnYh19hHNl9v0Q4Tkb7ogy7biw8icOvMPWCrxd50zoY1EUe2maNKtkyytJeEOV+Yj6VaUa88M+7WSKaK8QbEB+fBKmwvGkknRKs8lYRoABChwqDZ7M+98pL057QbquxseKX6alzV+IOHlO7I8csRHnF6OpVQG6wdzJZhEwg/0n1K2qTExF9Z3WzUoiQ+NVC3gRQ9Au+x3fpkuLu85lOVjelX3JtdVt1T3623sqxmcEr/TCZT/+X1QyflAkJyw1EMq4sat7wjYK3ugyPPPjo/v2h4TuaoWc0X/+qJPV/o2Vu489loIQ+N59ABZOLldpbkPM7VJIOnnfg+/GMvaEor2YCYElDGXx4BdRmSfOyzFF2Wqz5iTxMbdCo3iZbPQqbFTdMOX7Hy3nT8vUOhCLo+Dkgb7B01nPnm7crmC+TOgi4iDLp4nCqx5OSiG5gd/m54gZHe6Cymwj+DbW303KcvpGBrk0xr2sGUkQiu4vxNz+uW09EyMNCp5cg2AUWG4w6ykTHXUeDDQG232C5K7/tTt8Z09Kp9v71PkwH7hmZUrjAERGvF29zucdTVRmzr++JMH82Sk0chAi5UFs/lbVcN/birI7OVl6okyO3+bKWMCuhje1huOgeZzSk5xKFrgJ1v69TyD1mOa5wYx6IskbWSrFW/sqrhURqpSlfdWVCZiaOHLb/UIgQ0s1xlHyZ0/YOBQFz1VCgKH217ALijV3FOr+q00761SKNFc/IhZLNtVhHhE7lutAjVqyme7RHKd4fjFOD6oREyDYXHULmDGPRTmlFSxwE4+n3N9AInajQLH82CGWO1nV3u7qSY5vSbuzQIxCr8OKQfW8AzTdNjUoEtU+ojprLZ4V4r1dr01eLqXSVJ12Yq9Sm/Ivu1SZkHJl9oIxNjWSbRRMoYIVH3yVv1HyGGajcmKuzIfORuBZm" diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 43cb34a4..c7df47bf 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -43,6 +43,8 @@ jobs: env: GROUP: "CPU" JULIA_NUM_THREADS: 12 + RETESTITEMS_NWORKERS: 4 + RETESTITEMS_NWORKER_THREADS: 2 - uses: julia-actions/julia-processcoverage@v1 with: directories: src,ext diff --git a/test/LocalPreferences.toml b/LocalPreferences.toml similarity index 100% rename from test/LocalPreferences.toml rename to LocalPreferences.toml diff --git a/Project.toml b/Project.toml index 6baa53da..0436322f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,23 +1,23 @@ name = "DeepEquilibriumNetworks" uuid = "6748aba7-0e9b-415e-a410-ae3cc0ecb334" authors = ["Avik Pal "] -version = "2.0.3" +version = "2.1.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2" ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" -LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" +LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f" -TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77" [weakdeps] LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" @@ -25,25 +25,54 @@ SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] -DeepEquilibriumNetworksLinearSolveSciMLSensitivityExt = ["LinearSolve", "SciMLSensitivity"] +DeepEquilibriumNetworksSciMLSensitivityExt = ["LinearSolve", "SciMLSensitivity"] DeepEquilibriumNetworksZygoteExt = "Zygote" [compat] -ADTypes = "0.2.5" +ADTypes = "0.2.5, 1" +Aqua = "0.8.7" ChainRulesCore = "1" +CommonSolve = "0.2.4" ConcreteStructs = "0.2" ConstructionBase = "1" DiffEqBase = "6.119" +ExplicitImports = "1.4.1" FastClosures = "0.3" -LinearAlgebra = "1" +Functors = "0.4.10" LinearSolve = "2.21.2" -Lux = "0.5.11" +Lux = "0.5.37" +LuxCUDA = "0.3.2" +LuxCore = "0.1.14" +LuxTestUtils = "0.1.15" +NLsolve = "4.5.1" +NonlinearSolve = "3.10.0" +OrdinaryDiffEq = "6.74.1" PrecompileTools = "1" -Random = "1" +Random = "1.10" +ReTestItems = "1.23.1" SciMLBase = "2" SciMLSensitivity = "7.43" -Statistics = "1" +StableRNGs = "1.0.2" +Statistics = "1.10" SteadyStateDiffEq = "2" -TruncatedStacktraces = "1.1" -Zygote = "0.6.67" -julia = "1.9" +Test = "1.10" +Zygote = "0.6.69" +julia = "1.10" + +[extras] +Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" +Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" +LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531" +NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56" +NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" +OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" +ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" +SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + +[targets] +test = ["Aqua", "ExplicitImports", "Functors", "LuxCUDA", "LuxTestUtils", "NLsolve", "NonlinearSolve", "OrdinaryDiffEq", "ReTestItems", "SciMLSensitivity", "StableRNGs", "Test", "Zygote"] diff --git a/README.md b/README.md index 3587db8c..876685a2 100644 --- a/README.md +++ b/README.md @@ -34,8 +34,7 @@ Random.seed!(rng, seed) model = Chain(Dense(2 => 2), DeepEquilibriumNetwork( - Parallel(+, Dense(2 => 2; use_bias=false), - Dense(2 => 2; use_bias=false)), + Parallel(+, Dense(2 => 2; use_bias=false), Dense(2 => 2; use_bias=false)), NewtonRaphson())) gdev = gpu_device() diff --git a/docs/make.jl b/docs/make.jl index 62440167..3117b436 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -7,10 +7,15 @@ bib = CitationBibliography(joinpath(@__DIR__, "ref.bib"); style=:authoryear) include("pages.jl") -makedocs(; sitename="Deep Equilibrium Networks", authors="Avik Pal et al.", - modules=[DeepEquilibriumNetworks], clean=true, doctest=true, linkcheck=true, +makedocs(; sitename="Deep Equilibrium Networks", + authors="Avik Pal et al.", + modules=[DeepEquilibriumNetworks], + clean=true, + doctest=true, + linkcheck=true, format=Documenter.HTML(; assets=["assets/favicon.ico"], canonical="https://docs.sciml.ai/DeepEquilibriumNetworks/stable/"), - plugins=[bib], pages) + plugins=[bib], + pages) deploydocs(; repo="github.com/SciML/DeepEquilibriumNetworks.jl.git", push_preview=true) diff --git a/docs/pages.jl b/docs/pages.jl index ac42f48d..5a82ffc3 100644 --- a/docs/pages.jl +++ b/docs/pages.jl @@ -1,9 +1,3 @@ -pages = [ - "Home" => "index.md", - "Tutorials" => [ - "tutorials/basic_mnist_deq.md", - "tutorials/reduced_dim_deq.md" - ], - "API References" => "api.md", - "References" => "references.md" -] +pages = ["Home" => "index.md", + "Tutorials" => ["tutorials/basic_mnist_deq.md", "tutorials/reduced_dim_deq.md"], + "API References" => "api.md", "References" => "references.md"] diff --git a/docs/src/index.md b/docs/src/index.md index 0cb693d7..2cb1befb 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -26,8 +26,7 @@ Random.seed!(rng, seed) model = Chain(Dense(2 => 2), DeepEquilibriumNetwork( - Parallel(+, Dense(2 => 2; use_bias=false), - Dense(2 => 2; use_bias=false)), + Parallel(+, Dense(2 => 2; use_bias=false), Dense(2 => 2; use_bias=false)), NewtonRaphson())) gdev = gpu_device() diff --git a/docs/src/tutorials/basic_mnist_deq.md b/docs/src/tutorials/basic_mnist_deq.md index 9e1085a8..f473ae55 100644 --- a/docs/src/tutorials/basic_mnist_deq.md +++ b/docs/src/tutorials/basic_mnist_deq.md @@ -66,8 +66,7 @@ function construct_model(solver; model_type::Symbol=:deq) # The input layer of the DEQ deq_model = Chain( - Parallel(+, - Conv((3, 3), 64 => 64, tanh; stride=1, pad=SamePad()), + Parallel(+, Conv((3, 3), 64 => 64, tanh; stride=1, pad=SamePad()), Conv((3, 3), 64 => 64, tanh; stride=1, pad=SamePad())), Conv((3, 3), 64 => 64, tanh; stride=1, pad=SamePad())) @@ -79,11 +78,11 @@ function construct_model(solver; model_type::Symbol=:deq) init = missing end - deq = DeepEquilibriumNetwork(deq_model, solver; init, verbose=false, - linsolve_kwargs=(; maxiters=10)) + deq = DeepEquilibriumNetwork( + deq_model, solver; init, verbose=false, linsolve_kwargs=(; maxiters=10)) - classifier = Chain(GroupNorm(64, 64, relu), GlobalMeanPool(), FlattenLayer(), - Dense(64, 10)) + classifier = Chain( + GroupNorm(64, 64, relu), GlobalMeanPool(), FlattenLayer(), Dense(64, 10)) model = Chain(; down, deq, classifier) @@ -132,8 +131,8 @@ function accuracy(model, data, ps, st) return total_correct / total end -function train_model(solver, model_type; data_train=zip(x_train, y_train), - data_test=zip(x_test, y_test)) +function train_model( + solver, model_type; data_train=zip(x_train, y_train), data_test=zip(x_test, y_test)) model, ps, st = construct_model(solver; model_type) model_st = Lux.Experimental.StatefulLuxLayer(model, nothing, st) diff --git a/docs/src/tutorials/reduced_dim_deq.md b/docs/src/tutorials/reduced_dim_deq.md index 0b00b9e1..9be703d2 100644 --- a/docs/src/tutorials/reduced_dim_deq.md +++ b/docs/src/tutorials/reduced_dim_deq.md @@ -53,8 +53,7 @@ function construct_model(solver; model_type::Symbol=:regdeq) down = Chain(FlattenLayer(), Dense(784 => 512, gelu)) # The input layer of the DEQ - deq_model = Chain(Parallel(+, - Dense(128 => 64, tanh), # Reduced dim of `128` + deq_model = Chain(Parallel(+, Dense(128 => 64, tanh), # Reduced dim of `128` Dense(512 => 64, tanh)), # Original dim of `512` Dense(64 => 64, tanh), Dense(64 => 128)) # Return the reduced dim of `128` @@ -65,12 +64,12 @@ function construct_model(solver; model_type::Symbol=:regdeq) else # This should preferably done via `ChainRulesCore.@ignore_derivatives`. But here # we are only using Zygote so this is fine. - init = WrappedFunction(x -> Zygote.@ignore(fill!(similar(x, 128, size(x, 2)), - false))) + init = WrappedFunction(x -> Zygote.@ignore(fill!( + similar(x, 128, size(x, 2)), false))) end - deq = DeepEquilibriumNetwork(deq_model, solver; init, verbose=false, - linsolve_kwargs=(; maxiters=10)) + deq = DeepEquilibriumNetwork( + deq_model, solver; init, verbose=false, linsolve_kwargs=(; maxiters=10)) classifier = Chain(Dense(128 => 128, gelu), Dense(128, 10)) @@ -121,8 +120,8 @@ function accuracy(model, data, ps, st) return total_correct / total end -function train_model(solver, model_type; data_train=zip(x_train, y_train), - data_test=zip(x_test, y_test)) +function train_model( + solver, model_type; data_train=zip(x_train, y_train), data_test=zip(x_test, y_test)) model, ps, st = construct_model(solver; model_type) model_st = Lux.Experimental.StatefulLuxLayer(model, nothing, st) diff --git a/ext/DeepEquilibriumNetworksLinearSolveSciMLSensitivityExt.jl b/ext/DeepEquilibriumNetworksSciMLSensitivityExt.jl similarity index 61% rename from ext/DeepEquilibriumNetworksLinearSolveSciMLSensitivityExt.jl rename to ext/DeepEquilibriumNetworksSciMLSensitivityExt.jl index 21cc34ca..fdc36591 100644 --- a/ext/DeepEquilibriumNetworksLinearSolveSciMLSensitivityExt.jl +++ b/ext/DeepEquilibriumNetworksSciMLSensitivityExt.jl @@ -1,11 +1,13 @@ -module DeepEquilibriumNetworksLinearSolveSciMLSensitivityExt +module DeepEquilibriumNetworksSciMLSensitivityExt # Linear Solve is a dependency of SciMLSensitivity, so we only need to load SciMLSensitivity # to load this extension -using LinearSolve, SciMLBase, SciMLSensitivity -import DeepEquilibriumNetworks: __default_sensealg +using LinearSolve: SimpleGMRES +using SciMLBase: SteadyStateProblem, ODEProblem +using SciMLSensitivity: SteadyStateAdjoint, GaussAdjoint, ZygoteVJP +using DeepEquilibriumNetworks: DEQs -@inline function __default_sensealg(prob::SteadyStateProblem) +@inline function DEQs.__default_sensealg(prob::SteadyStateProblem) # We want to avoid the cost for cache construction for linsolve = nothing # For small problems we should use concrete jacobian but we assume users want to solve # large problems with this package so we default to GMRES and avoid runtime dispatches @@ -13,6 +15,6 @@ import DeepEquilibriumNetworks: __default_sensealg linsolve_kwargs = (; maxiters=10, abstol=1e-3, reltol=1e-3) return SteadyStateAdjoint(; linsolve, linsolve_kwargs, autojacvec=ZygoteVJP()) end -@inline __default_sensealg(::ODEProblem) = GaussAdjoint(; autojacvec=ZygoteVJP()) +@inline DEQs.__default_sensealg(prob::ODEProblem) = GaussAdjoint(; autojacvec=ZygoteVJP()) end diff --git a/ext/DeepEquilibriumNetworksZygoteExt.jl b/ext/DeepEquilibriumNetworksZygoteExt.jl index 56fd849c..7b848443 100644 --- a/ext/DeepEquilibriumNetworksZygoteExt.jl +++ b/ext/DeepEquilibriumNetworksZygoteExt.jl @@ -1,11 +1,18 @@ module DeepEquilibriumNetworksZygoteExt -using ADTypes, Statistics, Zygote -import DeepEquilibriumNetworks: __gaussian_like, __estimate_jacobian_trace +using ADTypes: AutoZygote +using FastClosures: @closure +using Statistics: mean +using Zygote: Zygote +using DeepEquilibriumNetworks: DEQs -function __estimate_jacobian_trace(::AutoZygote, model, ps, z, x, rng) - res, back = Zygote.pullback(u -> model((u, x), ps), z) - vjp_z = only(back(__gaussian_like(rng, res))) +@inline __tupleify(u) = @closure x -> (u, x) + +## Don't remove `ad`. See https://github.com/ericphanson/ExplicitImports.jl/issues/33 +## FIXME: This will be broken in the new Lux release let's fix this +function DEQs.__estimate_jacobian_trace(ad::AutoZygote, model, z, x, rng) + res, back = Zygote.pullback(model ∘ __tupleify, z) + vjp_z = only(back(DEQs.__gaussian_like(rng, res))) return mean(abs2, vjp_z) end diff --git a/src/DeepEquilibriumNetworks.jl b/src/DeepEquilibriumNetworks.jl index c7fedef5..abaccfbb 100644 --- a/src/DeepEquilibriumNetworks.jl +++ b/src/DeepEquilibriumNetworks.jl @@ -3,19 +3,25 @@ module DeepEquilibriumNetworks import PrecompileTools: @recompile_invalidations @recompile_invalidations begin - using ADTypes, DiffEqBase, FastClosures, LinearAlgebra, Lux, Random, SciMLBase, - Statistics, SteadyStateDiffEq - - import ChainRulesCore as CRC - import ConcreteStructs: @concrete - import ConstructionBase: constructorof - import Lux: AbstractExplicitLayer, AbstractExplicitContainerLayer - import SciMLBase: AbstractNonlinearAlgorithm, - AbstractODEAlgorithm, _unwrap_val, NonlinearSolution - import TruncatedStacktraces: @truncate_stacktrace + using ADTypes: AutoFiniteDiff + using ChainRulesCore: ChainRulesCore + using CommonSolve: solve + using ConcreteStructs: @concrete + using ConstructionBase: ConstructionBase + using DiffEqBase: DiffEqBase, AbsNormTerminationMode + using FastClosures: @closure + using Lux: Lux, BranchLayer, Chain, NoOpLayer, Parallel, RepeatedLayer, + StatefulLuxLayer, WrappedFunction + using LuxCore: AbstractExplicitLayer, AbstractExplicitContainerLayer + using Random: Random, AbstractRNG, randn! + using SciMLBase: SciMLBase, AbstractNonlinearAlgorithm, AbstractODEAlgorithm, + NonlinearSolution, ODESolution, ODEFunction, ODEProblem, + SteadyStateProblem, _unwrap_val + using SteadyStateDiffEq: DynamicSS, SSRootfind end # Useful Constants +const CRC = ChainRulesCore const DEQs = DeepEquilibriumNetworks include("layers.jl") diff --git a/src/layers.jl b/src/layers.jl index 935a3fa2..e20b179b 100644 --- a/src/layers.jl +++ b/src/layers.jl @@ -22,13 +22,13 @@ struct DeepEquilibriumSolution # This is intentionally left untyped to allow up original end -function CRC.rrule(::Type{<:DeepEquilibriumSolution}, z_star, u0, residual, jacobian_loss, - nfe, original) +function CRC.rrule(::Type{<:DeepEquilibriumSolution}, z_star, + u0, residual, jacobian_loss, nfe, original) sol = DeepEquilibriumSolution(z_star, u0, residual, jacobian_loss, nfe, original) ∇DeepEquilibriumSolution(::CRC.NoTangent) = ntuple(_ -> CRC.NoTangent(), 7) function ∇DeepEquilibriumSolution(∂sol) - return (CRC.NoTangent(), ∂sol.z_star, ∂sol.u0, ∂sol.residual, ∂sol.jacobian_loss, - ∂sol.nfe, CRC.NoTangent()) + return (CRC.NoTangent(), ∂sol.z_star, ∂sol.u0, ∂sol.residual, + ∂sol.jacobian_loss, ∂sol.nfe, CRC.NoTangent()) end return sol, ∇DeepEquilibriumSolution end @@ -39,10 +39,14 @@ end function Base.show(io::IO, sol::DeepEquilibriumSolution) println(io, "DeepEquilibriumSolution") - println(io, " * Initial Guess: ", sol.u0) - println(io, " * Steady State: ", sol.z_star) - println(io, " * Residual: ", sol.residual) - println(io, " * Jacobian Loss: ", sol.jacobian_loss) + println(io, " * Initial Guess: ", + sprint(print, sol.u0; context=(:compact => true, :limit => true))) + println(io, " * Steady State: ", + sprint(print, sol.z_star; context=(:compact => true, :limit => true))) + println(io, " * Residual: ", + sprint(print, sol.residual; context=(:compact => true, :limit => true))) + println(io, " * Jacobian Loss: ", + sprint(print, sol.jacobian_loss; context=(:compact => true, :limit => true))) print(io, " * NFE: ", sol.nfe) end @@ -56,11 +60,9 @@ end kwargs end -@truncate_stacktrace DeepEquilibriumNetwork 3 2 - const DEQ = DeepEquilibriumNetwork -constructorof(::Type{<:DEQ{pType}}) where {pType} = DEQ{pType} +ConstructionBase.constructorof(::Type{<:DEQ{pType}}) where {pType} = DEQ{pType} function Lux.initialstates(rng::AbstractRNG, deq::DEQ) rng = Lux.replicate(rng) @@ -76,17 +78,17 @@ function (deq::DEQ)(x, ps, st::NamedTuple, ::Val{true}) z, st = __get_initial_condition(deq, x, ps, st) repeated_model = RepeatedLayer(deq.model; repeats=st.fixed_depth) - zˢᵗᵃʳ, st_ = repeated_model((z, x), ps.model, st.model) - model = Lux.Experimental.StatefulLuxLayer(deq.model, nothing, st_) - resid = CRC.ignore_derivatives(zˢᵗᵃʳ .- model((zˢᵗᵃʳ, x), ps.model)) + z_star, st_ = repeated_model((z, x), ps.model, st.model) + model = StatefulLuxLayer(deq.model, ps.model, st_) + resid = CRC.ignore_derivatives(z_star .- model((z_star, x))) rng = Lux.replicate(st.rng) - jac_loss = __estimate_jacobian_trace(__getproperty(deq, Val(:jacobian_regularization)), - model, ps.model, zˢᵗᵃʳ, x, rng) + jac_loss = __estimate_jacobian_trace( + __getproperty(deq, Val(:jacobian_regularization)), model, z_star, x, rng) - solution = DeepEquilibriumSolution(zˢᵗᵃʳ, z, resid, zero(eltype(x)), - _unwrap_val(st.fixed_depth), jac_loss) - res = __split_and_reshape(zˢᵗᵃʳ, __getproperty(deq.model, Val(:split_idxs)), + solution = DeepEquilibriumSolution( + z_star, z, resid, zero(eltype(x)), _unwrap_val(st.fixed_depth), jac_loss) + res = __split_and_reshape(z_star, __getproperty(deq.model, Val(:split_idxs)), __getproperty(deq.model, Val(:scales))) return res, (; st..., model=model.st, solution, rng) @@ -95,7 +97,7 @@ end function (deq::DEQ{pType})(x, ps, st::NamedTuple, ::Val{false}) where {pType} z, st = __get_initial_condition(deq, x, ps, st) - model = Lux.Experimental.StatefulLuxLayer(deq.model, nothing, st.model) + model = StatefulLuxLayer(deq.model, ps.model, st.model) dudt = @closure (u, p, t) -> begin # The type-assert is needed because of an upstream Lux issue with type stability of @@ -106,17 +108,18 @@ function (deq::DEQ{pType})(x, ps, st::NamedTuple, ::Val{false}) where {pType} prob = __construct_prob(pType, ODEFunction{false}(dudt), z, (; ps=ps.model, x)) alg = __normalize_alg(deq) - sol = solve(prob, alg; sensealg=__default_sensealg(prob), abstol=1e-3, reltol=1e-3, - termination_condition=AbsNormTerminationMode(), maxiters=32, deq.kwargs...) - zˢᵗᵃʳ = __get_steady_state(sol) + termination_condition = AbsNormTerminationMode(Base.Fix1(maximum, abs)) + sol = solve(prob, alg; sensealg=__default_sensealg(prob), abstol=1e-3, + reltol=1e-3, termination_condition, maxiters=32, deq.kwargs...) + z_star = __get_steady_state(sol) rng = Lux.replicate(st.rng) - jac_loss = __estimate_jacobian_trace(__getproperty(deq, Val(:jacobian_regularization)), - model, ps.model, zˢᵗᵃʳ, x, rng) + jac_loss = __estimate_jacobian_trace( + __getproperty(deq, Val(:jacobian_regularization)), model, z_star, x, rng) - solution = DeepEquilibriumSolution(zˢᵗᵃʳ, z, __getproperty(sol, Val(:resid)), jac_loss, - __get_nfe(sol), sol) - res = __split_and_reshape(zˢᵗᵃʳ, __getproperty(deq.model, Val(:split_idxs)), + solution = DeepEquilibriumSolution( + z_star, z, __getproperty(sol, Val(:resid)), jac_loss, __get_nfe(sol), sol) + res = __split_and_reshape(z_star, __getproperty(deq.model, Val(:split_idxs)), __getproperty(deq.model, Val(:scales))) return res, (; st..., model=model.st, solution, rng) @@ -153,8 +156,7 @@ Deep Equilibrium Network as proposed in [baideep2019](@cite) and [pal2022mixing] julia> using DeepEquilibriumNetworks, Lux, Random, OrdinaryDiffEq julia> model = DeepEquilibriumNetwork( - Parallel(+, Dense(2, 2; use_bias=false), - Dense(2, 2; use_bias=false)), + Parallel(+, Dense(2, 2; use_bias=false), Dense(2, 2; use_bias=false)), VCABM3(); verbose=false) DeepEquilibriumNetwork( model = Parallel( @@ -178,8 +180,8 @@ julia> model(ones(Float32, 2, 1), ps, st); See also: [`SkipDeepEquilibriumNetwork`](@ref), [`MultiScaleDeepEquilibriumNetwork`](@ref), [`MultiScaleSkipDeepEquilibriumNetwork`](@ref). """ -function DeepEquilibriumNetwork(model, solver; init=missing, - jacobian_regularization=nothing, +function DeepEquilibriumNetwork( + model, solver; init=missing, jacobian_regularization=nothing, problem_type::Type{pType}=SteadyStateProblem{false}, kwargs...) where {pType} model isa AbstractExplicitLayer || (model = Lux.transform(model)) @@ -190,8 +192,8 @@ function DeepEquilibriumNetwork(model, solver; init=missing, elseif !(init isa AbstractExplicitLayer) init = Lux.transform(init) end - return DeepEquilibriumNetwork{pType}(init, model, solver, jacobian_regularization, - kwargs) + return DeepEquilibriumNetwork{pType}( + init, model, solver, jacobian_regularization, kwargs) end """ @@ -236,10 +238,8 @@ For keyword arguments, see [`DeepEquilibriumNetwork`](@ref). julia> using DeepEquilibriumNetworks, Lux, Random, NonlinearSolve julia> main_layers = ( - Parallel(+, Dense(4 => 4, tanh; use_bias=false), - Dense(4 => 4, tanh; use_bias=false)), - Dense(3 => 3, tanh), Dense(2 => 2, tanh), - Dense(1 => 1, tanh)) + Parallel(+, Dense(4 => 4, tanh; use_bias=false), Dense(4 => 4, tanh; use_bias=false)), + Dense(3 => 3, tanh), Dense(2 => 2, tanh), Dense(1 => 1, tanh)) (Parallel(), Dense(3 => 3, tanh_fast), Dense(2 => 2, tanh_fast), Dense(1 => 1, tanh_fast)) julia> mapping_layers = [NoOpLayer() Dense(4 => 3, tanh) Dense(4 => 2, tanh) Dense(4 => 1, tanh); @@ -252,8 +252,8 @@ julia> mapping_layers = [NoOpLayer() Dense(4 => 3, tanh) Dense(4 => 2, tanh) Den Dense(2 => 4, tanh_fast) Dense(2 => 1, tanh_fast) Dense(1 => 4, tanh_fast) NoOpLayer() -julia> model = MultiScaleDeepEquilibriumNetwork(main_layers, mapping_layers, nothing, - NewtonRaphson(), ((4,), (3,), (2,), (1,))) +julia> model = MultiScaleDeepEquilibriumNetwork( + main_layers, mapping_layers, nothing, NewtonRaphson(), ((4,), (3,), (2,), (1,))) DeepEquilibriumNetwork( model = MultiScaleInputLayer{scales = 4}( model = Chain( @@ -314,9 +314,9 @@ julia> model(x, ps, st); ``` """ -function MultiScaleDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Matrix, - post_fuse_layer::Union{Nothing, Tuple}, solver, scales; - jacobian_regularization=nothing, kwargs...) +function MultiScaleDeepEquilibriumNetwork( + main_layers::Tuple, mapping_layers::Matrix, post_fuse_layer::Union{Nothing, Tuple}, + solver, scales; jacobian_regularization=nothing, kwargs...) @assert jacobian_regularization===nothing "Jacobian Regularization is not supported yet for MultiScale Models." l1 = Parallel(nothing, main_layers...) l2 = BranchLayer(Parallel.(+, map(x -> tuple(x...), eachrow(mapping_layers))...)...) @@ -327,8 +327,8 @@ function MultiScaleDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Ma if post_fuse_layer === nothing model = MultiScaleInputLayer(Chain(l1, l2), split_idxs, scales) else - model = MultiScaleInputLayer(Chain(l1, l2, Parallel(nothing, post_fuse_layer...)), - split_idxs, scales) + model = MultiScaleInputLayer( + Chain(l1, l2, Parallel(nothing, post_fuse_layer...)), split_idxs, scales) end return DeepEquilibriumNetwork(model, solver; kwargs...) @@ -347,14 +347,14 @@ If `init` is not passed, it creates a MultiScale Regularized Deep Equilibrium Ne function MultiScaleSkipDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Matrix, post_fuse_layer::Union{Nothing, Tuple}, init::Tuple, solver, scales; kwargs...) init = Chain(Parallel(nothing, init...), __flatten_vcat) - return MultiScaleDeepEquilibriumNetwork(main_layers, mapping_layers, post_fuse_layer, - solver, scales; init, kwargs...) + return MultiScaleDeepEquilibriumNetwork( + main_layers, mapping_layers, post_fuse_layer, solver, scales; init, kwargs...) end function MultiScaleSkipDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Matrix, post_fuse_layer::Union{Nothing, Tuple}, args...; kwargs...) - return MultiScaleDeepEquilibriumNetwork(main_layers, mapping_layers, post_fuse_layer, - args...; init=nothing, kwargs...) + return MultiScaleDeepEquilibriumNetwork( + main_layers, mapping_layers, post_fuse_layer, args...; init=nothing, kwargs...) end """ @@ -364,13 +364,13 @@ Same arguments as [`MultiScaleDeepEquilibriumNetwork`](@ref) but sets `problem_t `ODEProblem{false}`. """ function MultiScaleNeuralODE(args...; kwargs...) - return MultiScaleDeepEquilibriumNetwork(args...; kwargs..., - problem_type=ODEProblem{false}) + return MultiScaleDeepEquilibriumNetwork( + args...; kwargs..., problem_type=ODEProblem{false}) end ## Generate Initial Condition -@inline function __get_initial_condition(deq::DEQ{pType, NoOpLayer}, x, ps, - st) where {pType} +@inline function __get_initial_condition( + deq::DEQ{pType, NoOpLayer}, x, ps, st) where {pType} zₓ = __zeros_init(__getproperty(deq.model, Val(:scales)), x) z, st_ = deq.model((zₓ, x), ps.model, st.model) return z, (; st..., model=st_) @@ -389,11 +389,11 @@ end scales end -constructorof(::Type{<:MultiScaleInputLayer{N}}) where {N} = MultiScaleInputLayer{N} +function ConstructionBase.constructorof(::Type{<:MultiScaleInputLayer{N}}) where {N} + return MultiScaleInputLayer{N} +end Lux.display_name(::MultiScaleInputLayer{N}) where {N} = "MultiScaleInputLayer{scales = $N}" -@truncate_stacktrace MultiScaleInputLayer 1 2 - function MultiScaleInputLayer(model, split_idxs, scales::Val{S}) where {S} return MultiScaleInputLayer{length(S)}(model, split_idxs, scales) end diff --git a/src/utils.jl b/src/utils.jl index 8de5a5d0..dfc13210 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,8 +1,8 @@ -@generated function __split_and_reshape(x::AbstractMatrix, ::Val{idxs}, - ::Val{shapes}) where {idxs, shapes} +@generated function __split_and_reshape( + x::AbstractMatrix, ::Val{idxs}, ::Val{shapes}) where {idxs, shapes} dims = [reshape((idxs[i] + 1):idxs[i + 1], shapes[i]...) for i in 1:(length(idxs) - 1)] varnames = map(_ -> gensym("x_view"), dims) - calls = [:($(varnames[i]) = x[$(dims[i]), :]) for i in 1:length(dims)] + calls = [:($(varnames[i]) = x[$(dims[i]), :]) for i in eachindex(dims)] return quote $(calls...) return tuple($(varnames...)) @@ -28,7 +28,7 @@ end function CRC.rrule(::typeof(__flatten_vcat), x) y = __flatten_vcat(x) project_x = CRC.ProjectTo(x) - function ∇__flatten_vcat(∂y) + ∇__flatten_vcat = @closure ∂y -> begin ∂y isa CRC.NoTangent && return (CRC.NoTangent(), CRC.NoTangent()) return CRC.NoTangent(), project_x(__split_and_reshape(∂y, x)) end @@ -52,7 +52,8 @@ end @inline __get_nfe(sol::ODESolution) = __get_nfe(sol.stats) @inline function __get_nfe(sol::NonlinearSolution) return ifelse(sol.stats === nothing, - ifelse(sol.original === nothing, -1, __get_nfe(sol.original)), __get_nfe(sol.stats)) + ifelse(sol.original === nothing, -1, __get_nfe(sol.original)), + __get_nfe(sol.stats)) end @inline __get_nfe(stats) = -1 @inline __get_nfe(stats::Union{SciMLBase.NLStats, SciMLBase.DEStats}) = stats.nf @@ -95,8 +96,9 @@ end CRC.@non_differentiable __gaussian_like(::Any...) # Jacobian Stabilization -function __estimate_jacobian_trace(::AutoFiniteDiff, model, ps, z, x, rng) - __f = u -> model((u, x), ps) +## Don't remove `ad`. See https://github.com/ericphanson/ExplicitImports.jl/issues/33 +function __estimate_jacobian_trace(ad::AutoFiniteDiff, model, z, x, rng) + __f = @closure u -> model((u, x)) res = zero(eltype(x)) ϵ = cbrt(eps(typeof(res))) ϵ⁻¹ = inv(ϵ) @@ -117,4 +119,4 @@ function __estimate_jacobian_trace(::AutoFiniteDiff, model, ps, z, x, rng) return res end -__estimate_jacobian_trace(::Nothing, model, ps, z, x, rng) = zero(eltype(x)) +__estimate_jacobian_trace(::Nothing, model, z, x, rng) = zero(eltype(x)) diff --git a/test/Project.toml b/test/Project.toml deleted file mode 100644 index 04cd6405..00000000 --- a/test/Project.toml +++ /dev/null @@ -1,26 +0,0 @@ -[deps] -ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" -Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" -Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" -LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -Lux = "b2108857-7c20-44ae-9111-449ecde12c47" -LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" -LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531" -NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56" -NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" -OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" -SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" -SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" -StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" -Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -TestSetExtensions = "98d24dd4-01ad-11ea-1b02-c9a08f80db04" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" - -[compat] -Aqua = "0.8" diff --git a/test/layers.jl b/test/layers.jl deleted file mode 100644 index 24dcf798..00000000 --- a/test/layers.jl +++ /dev/null @@ -1,180 +0,0 @@ -using ADTypes, DeepEquilibriumNetworks, DiffEqBase, NonlinearSolve, OrdinaryDiffEq, - SciMLSensitivity, SciMLBase, Test - -include("test_utils.jl") - -function loss_function(model, x, ps, st) - y, st = model(x, ps, st) - l1 = y isa Tuple ? sum(Base.Fix1(sum, abs2), y) : sum(abs2, y) - l2 = st.solution.jacobian_loss - l3 = sum(abs2, st.solution.z_star .- st.solution.u0) - return l1 + l2 + l3 -end - -@testset "DeepEquilibriumNetwork: $(mode)" for (mode, aType, dev, ongpu) in MODES - rng = __get_prng(0) - - base_models = [ - Parallel(+, __get_dense_layer(2 => 2), __get_dense_layer(2 => 2)), - Parallel(+, __get_conv_layer((1, 1), 1 => 1), __get_conv_layer((1, 1), 1 => 1)) - ] - init_models = [__get_dense_layer(2 => 2), __get_conv_layer((1, 1), 1 => 1)] - x_sizes = [(2, 14), (3, 3, 1, 3)] - - model_type = (:deq, :skipdeq, :skipregdeq) - solvers = (VCAB3(), Tsit5(), - NewtonRaphson(; autodiff=AutoForwardDiff(; chunksize=12)), - SimpleLimitedMemoryBroyden()) - jacobian_regularizations = Any[nothing, AutoZygote()] - !ongpu && push!(jacobian_regularizations, AutoFiniteDiff()) - - @testset "Solver: $(__nameof(solver))" for solver in solvers, - mtype in model_type, jacobian_regularization in jacobian_regularizations - - @testset "x_size: $(x_size)" for (base_model, init_model, x_size) in zip( - base_models, - init_models, x_sizes) - model = if mtype === :deq - DeepEquilibriumNetwork(base_model, solver; jacobian_regularization) - elseif mtype === :skipdeq - SkipDeepEquilibriumNetwork(base_model, init_model, solver; - jacobian_regularization) - elseif mtype === :skipregdeq - SkipDeepEquilibriumNetwork(base_model, solver; jacobian_regularization) - end - - ps, st = Lux.setup(rng, model) |> dev - @test st.solution == DeepEquilibriumSolution() - - x = randn(rng, Float32, x_size...) |> dev - z, st = model(x, ps, st) - - opt_broken = solver isa SimpleLimitedMemoryBroyden - @jet model(x, ps, st) opt_broken=opt_broken - - @test all(isfinite, z) - @test size(z) == size(x) - @test st.solution isa DeepEquilibriumSolution - @test maximum(abs, st.solution.residual) ≤ 1e-3 - - _, gs_x, gs_ps, _ = Zygote.gradient(loss_function, model, x, ps, st) - - @test __is_finite_gradient(gs_x) - @test __is_finite_gradient(gs_ps) - - ps, st = Lux.setup(rng, model) |> dev - st = Lux.update_state(st, :fixed_depth, Val(10)) - @test st.solution == DeepEquilibriumSolution() - - z, st = model(x, ps, st) - @jet model(x, ps, st) - - @test all(isfinite, z) - @test size(z) == size(x) - @test st.solution isa DeepEquilibriumSolution - @test st.solution.nfe == 10 - - _, gs_x, gs_ps, _ = Zygote.gradient(loss_function, model, x, ps, st) - - @test __is_finite_gradient(gs_x) - @test __is_finite_gradient(gs_ps) - end - end -end - -@testset "MultiScaleDeepEquilibriumNetwork: $(mode)" for (mode, aType, dev, ongpu) in MODES - rng = __get_prng(0) - - main_layers = [ - (Parallel(+, __get_dense_layer(4 => 4), __get_dense_layer(4 => 4)), - __get_dense_layer(3 => 3), __get_dense_layer(2 => 2), - __get_dense_layer(1 => 1)) - ] - - mapping_layers = [ - [NoOpLayer() __get_dense_layer(4 => 3) __get_dense_layer(4 => 2) __get_dense_layer(4 => 1); - __get_dense_layer(3 => 4) NoOpLayer() __get_dense_layer(3 => 2) __get_dense_layer(3 => 1); - __get_dense_layer(2 => 4) __get_dense_layer(2 => 3) NoOpLayer() __get_dense_layer(2 => 1); - __get_dense_layer(1 => 4) __get_dense_layer(1 => 3) __get_dense_layer(1 => 2) NoOpLayer()] - ] - - init_layers = [ - (__get_dense_layer(4 => 4), __get_dense_layer(4 => 3), __get_dense_layer(4 => 2), - __get_dense_layer(4 => 1)) - ] - - x_sizes = [(4, 3)] - scales = [((4,), (3,), (2,), (1,))] - - model_type = (:deq, :skipdeq, :skipregdeq, :node) - solvers = (VCAB3(), Tsit5(), - NewtonRaphson(; autodiff=AutoForwardDiff(; chunksize=12)), - SimpleLimitedMemoryBroyden()) - jacobian_regularizations = (nothing,) - - for mtype in model_type, jacobian_regularization in jacobian_regularizations - @testset "Solver: $(__nameof(solver))" for solver in solvers - @testset "x_size: $(x_size)" for (main_layer, mapping_layer, init_layer, x_size, scale) in zip( - main_layers, - mapping_layers, init_layers, x_sizes, scales) - model = if mtype === :deq - MultiScaleDeepEquilibriumNetwork(main_layer, mapping_layer, nothing, - solver, scale; jacobian_regularization) - elseif mtype === :skipdeq - MultiScaleSkipDeepEquilibriumNetwork( - main_layer, mapping_layer, nothing, - init_layer, solver, scale; jacobian_regularization) - elseif mtype === :skipregdeq - MultiScaleSkipDeepEquilibriumNetwork( - main_layer, mapping_layer, nothing, - solver, scale; jacobian_regularization) - elseif mtype === :node - solver isa SciMLBase.AbstractODEAlgorithm || continue - MultiScaleNeuralODE(main_layer, mapping_layer, nothing, solver, scale; - jacobian_regularization) - end - - ps, st = Lux.setup(rng, model) |> dev - @test st.solution == DeepEquilibriumSolution() - - x = randn(rng, Float32, x_size...) |> dev - z, st = model(x, ps, st) - z_ = DEQs.__flatten_vcat(z) - - opt_broken = solver isa SimpleLimitedMemoryBroyden - @jet model(x, ps, st) opt_broken=opt_broken # Broken due to nfe dynamic dispatch - - @test all(isfinite, z_) - @test size(z_) == (sum(prod, scale), size(x, ndims(x))) - @test st.solution isa DeepEquilibriumSolution - if st.solution.residual !== nothing - @test maximum(abs, st.solution.residual) ≤ 1e-3 - end - - _, gs_x, gs_ps, _ = Zygote.gradient(loss_function, model, x, ps, st) - - @test __is_finite_gradient(gs_x) - @test __is_finite_gradient(gs_ps) - - ps, st = Lux.setup(rng, model) |> dev - st = Lux.update_state(st, :fixed_depth, Val(10)) - @test st.solution == DeepEquilibriumSolution() - - z, st = model(x, ps, st) - z_ = DEQs.__flatten_vcat(z) - opt_broken = jacobian_regularization isa AutoZygote - @jet model(x, ps, st) opt_broken=opt_broken - - @test all(isfinite, z_) - @test size(z_) == (sum(prod, scale), size(x, ndims(x))) - @test st.solution isa DeepEquilibriumSolution - @test st.solution.nfe == 10 - - _, gs_x, gs_ps, _ = Zygote.gradient(loss_function, model, x, ps, st) - - @test __is_finite_gradient(gs_x) - @test __is_finite_gradient(gs_ps) - end - end - end -end diff --git a/test/layers_tests.jl b/test/layers_tests.jl new file mode 100644 index 00000000..75b6f68d --- /dev/null +++ b/test/layers_tests.jl @@ -0,0 +1,180 @@ +@testsetup module LayersTestSetup + +using NonlinearSolve, OrdinaryDiffEq + +function loss_function(model, x, ps, st) + y, st = model(x, ps, st) + l1 = y isa Tuple ? sum(Base.Fix1(sum, abs2), y) : sum(abs2, y) + l2 = st.solution.jacobian_loss + l3 = sum(abs2, st.solution.z_star .- st.solution.u0) + return l1 + l2 + l3 +end + +SOLVERS = (VCAB3(), Tsit5(), NewtonRaphson(; autodiff=AutoForwardDiff(; chunksize=12)), + SimpleLimitedMemoryBroyden()) + +export loss_function, SOLVERS + +end + +@testitem "DEQ" setup=[SharedTestSetup, LayersTestSetup] begin + using ADTypes, Lux, NonlinearSolve, OrdinaryDiffEq, SciMLSensitivity, Zygote + + rng = __get_prng(0) + + base_models = [Parallel(+, __get_dense_layer(2 => 2), __get_dense_layer(2 => 2)), + Parallel(+, __get_conv_layer((1, 1), 1 => 1), __get_conv_layer((1, 1), 1 => 1))] + init_models = [__get_dense_layer(2 => 2), __get_conv_layer((1, 1), 1 => 1)] + x_sizes = [(2, 14), (3, 3, 1, 3)] + + model_type = (:deq, :skipdeq, :skipregdeq) + _jacobian_regularizations = (nothing, AutoZygote(), AutoFiniteDiff()) + + @testset "$mode" for (mode, aType, dev, ongpu) in MODES + jacobian_regularizations = ongpu ? _jacobian_regularizations[1:(end - 1)] : + _jacobian_regularizations + + @testset "Solver: $(__nameof(solver))" for solver in SOLVERS, + mtype in model_type, + jacobian_regularization in jacobian_regularizations + + @testset "x_size: $(x_size)" for (base_model, init_model, x_size) in zip( + base_models, init_models, x_sizes) + model = if mtype === :deq + DeepEquilibriumNetwork(base_model, solver; jacobian_regularization) + elseif mtype === :skipdeq + SkipDeepEquilibriumNetwork( + base_model, init_model, solver; jacobian_regularization) + elseif mtype === :skipregdeq + SkipDeepEquilibriumNetwork(base_model, solver; jacobian_regularization) + end + + ps, st = Lux.setup(rng, model) |> dev + @test st.solution == DeepEquilibriumSolution() + + x = randn(rng, Float32, x_size...) |> dev + z, st = model(x, ps, st) + + opt_broken = solver isa SimpleLimitedMemoryBroyden + @jet model(x, ps, st) opt_broken=opt_broken + + @test all(isfinite, z) + @test size(z) == size(x) + @test st.solution isa DeepEquilibriumSolution + @test maximum(abs, st.solution.residual) ≤ 1e-3 + + _, gs_x, gs_ps, _ = Zygote.gradient(loss_function, model, x, ps, st) + + @test __is_finite_gradient(gs_x) + @test __is_finite_gradient(gs_ps) + + ps, st = Lux.setup(rng, model) |> dev + st = Lux.update_state(st, :fixed_depth, Val(10)) + @test st.solution == DeepEquilibriumSolution() + + z, st = model(x, ps, st) + @jet model(x, ps, st) + + @test all(isfinite, z) + @test size(z) == size(x) + @test st.solution isa DeepEquilibriumSolution + @test st.solution.nfe == 10 + + _, gs_x, gs_ps, _ = Zygote.gradient(loss_function, model, x, ps, st) + + @test __is_finite_gradient(gs_x) + @test __is_finite_gradient(gs_ps) + end + end + end +end + +@testitem "Multiscale DEQ" setup=[SharedTestSetup, LayersTestSetup] begin + using ADTypes, Lux, NonlinearSolve, OrdinaryDiffEq, SciMLSensitivity, Zygote + + rng = __get_prng(0) + + main_layers = [(Parallel(+, __get_dense_layer(4 => 4), __get_dense_layer(4 => 4)), + __get_dense_layer(3 => 3), __get_dense_layer(2 => 2), __get_dense_layer(1 => 1))] + + mapping_layers = [[NoOpLayer() __get_dense_layer(4 => 3) __get_dense_layer(4 => 2) __get_dense_layer(4 => 1); + __get_dense_layer(3 => 4) NoOpLayer() __get_dense_layer(3 => 2) __get_dense_layer(3 => 1); + __get_dense_layer(2 => 4) __get_dense_layer(2 => 3) NoOpLayer() __get_dense_layer(2 => 1); + __get_dense_layer(1 => 4) __get_dense_layer(1 => 3) __get_dense_layer(1 => 2) NoOpLayer()]] + + init_layers = [(__get_dense_layer(4 => 4), __get_dense_layer(4 => 3), + __get_dense_layer(4 => 2), __get_dense_layer(4 => 1))] + + x_sizes = [(4, 3)] + scales = [((4,), (3,), (2,), (1,))] + + model_type = (:deq, :skipdeq, :skipregdeq, :node) + jacobian_regularizations = (nothing,) + + @testset "$mode" for (mode, aType, dev, ongpu) in MODES + @testset "Solver: $(__nameof(solver))" for solver in SOLVERS, + mtype in model_type, + jacobian_regularization in jacobian_regularizations + + @testset "x_size: $(x_size)" for (main_layer, mapping_layer, init_layer, x_size, scale) in zip( + main_layers, mapping_layers, init_layers, x_sizes, scales) + model = if mtype === :deq + MultiScaleDeepEquilibriumNetwork(main_layer, mapping_layer, nothing, + solver, scale; jacobian_regularization) + elseif mtype === :skipdeq + MultiScaleSkipDeepEquilibriumNetwork( + main_layer, mapping_layer, nothing, init_layer, + solver, scale; jacobian_regularization) + elseif mtype === :skipregdeq + MultiScaleSkipDeepEquilibriumNetwork(main_layer, mapping_layer, nothing, + solver, scale; jacobian_regularization) + elseif mtype === :node + solver isa SciMLBase.AbstractODEAlgorithm || continue + MultiScaleNeuralODE(main_layer, mapping_layer, nothing, + solver, scale; jacobian_regularization) + end + + ps, st = Lux.setup(rng, model) |> dev + @test st.solution == DeepEquilibriumSolution() + + x = randn(rng, Float32, x_size...) |> dev + z, st = model(x, ps, st) + z_ = DEQs.__flatten_vcat(z) + + opt_broken = solver isa SimpleLimitedMemoryBroyden + @jet model(x, ps, st) opt_broken=opt_broken # Broken due to nfe dynamic dispatch + + @test all(isfinite, z_) + @test size(z_) == (sum(prod, scale), size(x, ndims(x))) + @test st.solution isa DeepEquilibriumSolution + if st.solution.residual !== nothing + @test maximum(abs, st.solution.residual) ≤ 1e-3 + end + + _, gs_x, gs_ps, _ = Zygote.gradient(loss_function, model, x, ps, st) + + @test __is_finite_gradient(gs_x) + @test __is_finite_gradient(gs_ps) + + ps, st = Lux.setup(rng, model) |> dev + st = Lux.update_state(st, :fixed_depth, Val(10)) + @test st.solution == DeepEquilibriumSolution() + + z, st = model(x, ps, st) + z_ = DEQs.__flatten_vcat(z) + opt_broken = jacobian_regularization isa AutoZygote + @jet model(x, ps, st) opt_broken=opt_broken + + @test all(isfinite, z_) + @test size(z_) == (sum(prod, scale), size(x, ndims(x))) + @test st.solution isa DeepEquilibriumSolution + @test st.solution.nfe == 10 + + _, gs_x, gs_ps, _ = Zygote.gradient(loss_function, model, x, ps, st) + + @test __is_finite_gradient(gs_x) + @test __is_finite_gradient(gs_ps) + end + end + end +end diff --git a/test/qa.jl b/test/qa.jl deleted file mode 100644 index 94eb43f6..00000000 --- a/test/qa.jl +++ /dev/null @@ -1,7 +0,0 @@ -using DeepEquilibriumNetworks, Aqua, Test -import ChainRulesCore as CRC - -@testset "Aqua" begin - Aqua.test_all(DeepEquilibriumNetworks; ambiguities=false) - Aqua.test_ambiguities(DeepEquilibriumNetworks; recursive=false) -end diff --git a/test/qa_tests.jl b/test/qa_tests.jl new file mode 100644 index 00000000..2dd1d11e --- /dev/null +++ b/test/qa_tests.jl @@ -0,0 +1,17 @@ +@testitem "Aqua" begin + using Aqua + + Aqua.test_all(DeepEquilibriumNetworks; ambiguities=false) + Aqua.test_ambiguities(DeepEquilibriumNetworks; recursive=false) +end + +@testitem "ExplicitImports" begin + import SciMLSensitivity, Zygote + + using ExplicitImports + + # Skip our own packages + @test check_no_implicit_imports(DeepEquilibriumNetworks) === nothing + ## AbstractRNG seems to be a spurious detection in LuxFluxExt + @test check_no_stale_explicit_imports(DeepEquilibriumNetworks) === nothing +end diff --git a/test/runtests.jl b/test/runtests.jl index 045828fa..8ba7978a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,7 +1,3 @@ -using SafeTestsets, Test, TestSetExtensions +using ReTestItems -@testset ExtendedTestSet "Deep Equilibrium Networks" begin - @safetestset "Quality Assurance" include("qa.jl") - @safetestset "Utilities" include("utils.jl") - @safetestset "Layers" include("layers.jl") -end +ReTestItems.runtests(@__DIR__) diff --git a/test/test_utils.jl b/test/shared_testsetup.jl similarity index 73% rename from test/test_utils.jl rename to test/shared_testsetup.jl index b9268716..b22de31b 100644 --- a/test/test_utils.jl +++ b/test/shared_testsetup.jl @@ -1,3 +1,5 @@ +@testsetup module SharedTestSetup + using DeepEquilibriumNetworks, Functors, Lux, Random, StableRNGs, Zygote import LuxTestUtils: @jet using LuxCUDA @@ -35,15 +37,19 @@ const GROUP = get(ENV, "GROUP", "All") cpu_testing() = GROUP == "All" || GROUP == "CPU" cuda_testing() = LuxCUDA.functional() && (GROUP == "All" || GROUP == "CUDA") -if !@isdefined(MODES) - const MODES = begin - cpu_mode = ("CPU", Array, LuxCPUDevice(), false) - cuda_mode = ("CUDA", CuArray, LuxCUDADevice(), true) +const MODES = begin + cpu_mode = ("CPU", Array, LuxCPUDevice(), false) + cuda_mode = ("CUDA", CuArray, LuxCUDADevice(), true) - modes = [] - cpu_testing() && push!(modes, cpu_mode) - cuda_testing() && push!(modes, cuda_mode) + modes = [] + cpu_testing() && push!(modes, cpu_mode) + cuda_testing() && push!(modes, cuda_mode) + + modes +end + +export Lux, LuxCore, LuxLib +export MODES, __get_dense_layer, __get_conv_layer, __is_finite_gradient, __get_prng, + __nameof, @jet - modes - end end diff --git a/test/utils.jl b/test/utils.jl deleted file mode 100644 index 2c0057a3..00000000 --- a/test/utils.jl +++ /dev/null @@ -1,38 +0,0 @@ -using DeepEquilibriumNetworks, LinearAlgebra, SciMLBase, Test - -include("test_utils.jl") - -@testset "split_and_reshape: $mode" for (mode, aType, dev, ongpu) in MODES - x1 = ones(Float32, 4, 4) |> aType - x2 = fill(0.5f0, 2, 4) |> aType - x3 = zeros(Float32, 1, 4) |> aType - - x = vcat(x1, x2, x3) - split_idxs = Val(cumsum((0, size(x1, 1), size(x2, 1), size(x3, 1)))) - shapes = Val((size(x1, 1), size(x2, 1), size(x3, 1))) - x_split = DEQs.__split_and_reshape(x, split_idxs, shapes) - - @test x1 == x_split[1] - @test x2 == x_split[2] - @test x3 == x_split[3] - - @jet DEQs.__split_and_reshape(x, split_idxs, shapes) -end - -@testset "unrolled_mode check" begin - @test SciMLBase._unwrap_val(DEQs.__check_unrolled_mode(Val(10))) - @test !SciMLBase._unwrap_val(DEQs.__check_unrolled_mode(Val(0))) - @test SciMLBase._unwrap_val(DEQs.__check_unrolled_mode((; fixed_depth=Val(10)))) - @test !SciMLBase._unwrap_val(DEQs.__check_unrolled_mode((; fixed_depth=Val(0)))) -end - -@testset "get unrolled_mode" begin - @test DEQs.__get_unrolled_depth(Val(10)) == 10 - @test DEQs.__get_unrolled_depth((; fixed_depth=Val(10))) == 10 -end - -@testset "deep equilibrium solution" begin - sol = @test_nowarn DeepEquilibriumSolution(randn(10), randn(10), randn(10), 0.4, 10, - nothing) - @test_nowarn println(sol) -end diff --git a/test/utils_tests.jl b/test/utils_tests.jl new file mode 100644 index 00000000..2d114a79 --- /dev/null +++ b/test/utils_tests.jl @@ -0,0 +1,38 @@ +@testitem "split_and_reshape" setup=[SharedTestSetup] begin + for (mode, aType, dev, ongpu) in MODES + x1 = ones(Float32, 4, 4) |> aType + x2 = fill(0.5f0, 2, 4) |> aType + x3 = zeros(Float32, 1, 4) |> aType + + x = vcat(x1, x2, x3) + split_idxs = Val(cumsum((0, size(x1, 1), size(x2, 1), size(x3, 1)))) + shapes = Val((size(x1, 1), size(x2, 1), size(x3, 1))) + x_split = DEQs.__split_and_reshape(x, split_idxs, shapes) + + @test x1 == x_split[1] + @test x2 == x_split[2] + @test x3 == x_split[3] + + @jet DEQs.__split_and_reshape(x, split_idxs, shapes) + end +end + +@testitem "unrolled_mode check" setup=[SharedTestSetup] begin + using SciMLBase + + @test SciMLBase._unwrap_val(DEQs.__check_unrolled_mode(Val(10))) + @test !SciMLBase._unwrap_val(DEQs.__check_unrolled_mode(Val(0))) + @test SciMLBase._unwrap_val(DEQs.__check_unrolled_mode((; fixed_depth=Val(10)))) + @test !SciMLBase._unwrap_val(DEQs.__check_unrolled_mode((; fixed_depth=Val(0)))) +end + +@testitem "get unrolled_mode" setup=[SharedTestSetup] begin + @test DEQs.__get_unrolled_depth(Val(10)) == 10 + @test DEQs.__get_unrolled_depth((; fixed_depth=Val(10))) == 10 +end + +@testitem "deep equilibrium solution" setup=[SharedTestSetup] begin + sol = @test_nowarn DeepEquilibriumSolution( + randn(10), randn(10), randn(10), 0.4, 10, nothing) + @test_nowarn println(sol) +end From 29e971e8b3676fa3b582d5331f2578d619363a88 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 24 Apr 2024 15:20:21 -0400 Subject: [PATCH 2/6] Add the manifest --- Manifest.toml | 1041 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 1041 insertions(+) create mode 100644 Manifest.toml diff --git a/Manifest.toml b/Manifest.toml new file mode 100644 index 00000000..142f0567 --- /dev/null +++ b/Manifest.toml @@ -0,0 +1,1041 @@ +# This file is machine-generated - editing it directly is not advised + +julia_version = "1.10.2" +manifest_format = "2.0" +project_hash = "df8a9208b4276382055ff54a66a4252730918e13" + +[[deps.ADTypes]] +git-tree-sha1 = "fcdb00b4d412b80ab08e39978e3bdef579e5e224" +uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b" +version = "1.0.0" +weakdeps = ["ChainRulesCore", "EnzymeCore"] + + [deps.ADTypes.extensions] + ADTypesChainRulesCoreExt = "ChainRulesCore" + ADTypesEnzymeCoreExt = "EnzymeCore" + +[[deps.Accessors]] +deps = ["CompositionsBase", "ConstructionBase", "Dates", "InverseFunctions", "LinearAlgebra", "MacroTools", "Markdown", "Test"] +git-tree-sha1 = "c0d491ef0b135fd7d63cbc6404286bc633329425" +uuid = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" +version = "0.1.36" + + [deps.Accessors.extensions] + AccessorsAxisKeysExt = "AxisKeys" + AccessorsIntervalSetsExt = "IntervalSets" + AccessorsStaticArraysExt = "StaticArrays" + AccessorsStructArraysExt = "StructArrays" + AccessorsUnitfulExt = "Unitful" + + [deps.Accessors.weakdeps] + AxisKeys = "94b1ba4f-4ee9-5380-92f1-94cde586c3c5" + IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" + Requires = "ae029012-a4dd-5104-9daa-d747884805df" + StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" + Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" + +[[deps.Adapt]] +deps = ["LinearAlgebra", "Requires"] +git-tree-sha1 = "6a55b747d1812e699320963ffde36f1ebdda4099" +uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +version = "4.0.4" +weakdeps = ["StaticArrays"] + + [deps.Adapt.extensions] + AdaptStaticArraysExt = "StaticArrays" + +[[deps.ArgTools]] +uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" +version = "1.1.1" + +[[deps.ArrayInterface]] +deps = ["Adapt", "LinearAlgebra", "SparseArrays", "SuiteSparse"] +git-tree-sha1 = "133a240faec6e074e07c31ee75619c90544179cf" +uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" +version = "7.10.0" + + [deps.ArrayInterface.extensions] + ArrayInterfaceBandedMatricesExt = "BandedMatrices" + ArrayInterfaceBlockBandedMatricesExt = "BlockBandedMatrices" + ArrayInterfaceCUDAExt = "CUDA" + ArrayInterfaceCUDSSExt = "CUDSS" + ArrayInterfaceChainRulesExt = "ChainRules" + ArrayInterfaceGPUArraysCoreExt = "GPUArraysCore" + ArrayInterfaceReverseDiffExt = "ReverseDiff" + ArrayInterfaceStaticArraysCoreExt = "StaticArraysCore" + ArrayInterfaceTrackerExt = "Tracker" + + [deps.ArrayInterface.weakdeps] + BandedMatrices = "aae01518-5342-5314-be14-df237901396f" + BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0" + CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + CUDSS = "45b445bb-4962-46a0-9369-b4df9d0f772e" + ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" + GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" + ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" + StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" + Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" + +[[deps.Artifacts]] +uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" + +[[deps.Atomix]] +deps = ["UnsafeAtomics"] +git-tree-sha1 = "c06a868224ecba914baa6942988e2f2aade419be" +uuid = "a9b6321e-bd34-4604-b9c9-b65b8de01458" +version = "0.1.0" + +[[deps.Base64]] +uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" + +[[deps.BitTwiddlingConvenienceFunctions]] +deps = ["Static"] +git-tree-sha1 = "0c5f81f47bbbcf4aea7b2959135713459170798b" +uuid = "62783981-4cbd-42fc-bca8-16325de8dc4b" +version = "0.1.5" + +[[deps.CEnum]] +git-tree-sha1 = "389ad5c84de1ae7cf0e28e381131c98ea87d54fc" +uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" +version = "0.5.0" + +[[deps.CPUSummary]] +deps = ["CpuId", "IfElse", "PrecompileTools", "Static"] +git-tree-sha1 = "601f7e7b3d36f18790e2caf83a882d88e9b71ff1" +uuid = "2a0fbf3d-bb9c-48f3-b0a9-814d99fd7ab9" +version = "0.2.4" + +[[deps.ChainRulesCore]] +deps = ["Compat", "LinearAlgebra"] +git-tree-sha1 = "575cd02e080939a33b6df6c5853d14924c08e35b" +uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +version = "1.23.0" +weakdeps = ["SparseArrays"] + + [deps.ChainRulesCore.extensions] + ChainRulesCoreSparseArraysExt = "SparseArrays" + +[[deps.CloseOpenIntervals]] +deps = ["Static", "StaticArrayInterface"] +git-tree-sha1 = "70232f82ffaab9dc52585e0dd043b5e0c6b714f1" +uuid = "fb6a15b2-703c-40df-9091-08a04967cfa9" +version = "0.1.12" + +[[deps.CommonSolve]] +git-tree-sha1 = "0eee5eb66b1cf62cd6ad1b460238e60e4b09400c" +uuid = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2" +version = "0.2.4" + +[[deps.CommonSubexpressions]] +deps = ["MacroTools", "Test"] +git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7" +uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" +version = "0.3.0" + +[[deps.Compat]] +deps = ["TOML", "UUIDs"] +git-tree-sha1 = "c955881e3c981181362ae4088b35995446298b80" +uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" +version = "4.14.0" +weakdeps = ["Dates", "LinearAlgebra"] + + [deps.Compat.extensions] + CompatLinearAlgebraExt = "LinearAlgebra" + +[[deps.CompilerSupportLibraries_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" +version = "1.1.0+0" + +[[deps.CompositionsBase]] +git-tree-sha1 = "802bb88cd69dfd1509f6670416bd4434015693ad" +uuid = "a33af91c-f02d-484b-be07-31d278c5ca2b" +version = "0.1.2" +weakdeps = ["InverseFunctions"] + + [deps.CompositionsBase.extensions] + CompositionsBaseInverseFunctionsExt = "InverseFunctions" + +[[deps.ConcreteStructs]] +git-tree-sha1 = "f749037478283d372048690eb3b5f92a79432b34" +uuid = "2569d6c7-a4a2-43d3-a901-331e8e4be471" +version = "0.2.3" + +[[deps.ConstructionBase]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "260fd2400ed2dab602a7c15cf10c1933c59930a2" +uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9" +version = "1.5.5" + + [deps.ConstructionBase.extensions] + ConstructionBaseIntervalSetsExt = "IntervalSets" + ConstructionBaseStaticArraysExt = "StaticArrays" + + [deps.ConstructionBase.weakdeps] + IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" + StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + +[[deps.CpuId]] +deps = ["Markdown"] +git-tree-sha1 = "fcbb72b032692610bfbdb15018ac16a36cf2e406" +uuid = "adafc99b-e345-5852-983c-f28acb93d879" +version = "0.3.1" + +[[deps.DataAPI]] +git-tree-sha1 = "abe83f3a2f1b857aac70ef8b269080af17764bbe" +uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" +version = "1.16.0" + +[[deps.DataStructures]] +deps = ["Compat", "InteractiveUtils", "OrderedCollections"] +git-tree-sha1 = "1d0a14036acb104d9e89698bd408f63ab58cdc82" +uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" +version = "0.18.20" + +[[deps.DataValueInterfaces]] +git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" +uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464" +version = "1.0.0" + +[[deps.Dates]] +deps = ["Printf"] +uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" + +[[deps.DiffEqBase]] +deps = ["ArrayInterface", "ConcreteStructs", "DataStructures", "DocStringExtensions", "EnumX", "EnzymeCore", "FastBroadcast", "FastClosures", "ForwardDiff", "FunctionWrappers", "FunctionWrappersWrappers", "LinearAlgebra", "Logging", "Markdown", "MuladdMacro", "Parameters", "PreallocationTools", "PrecompileTools", "Printf", "RecursiveArrayTools", "Reexport", "SciMLBase", "SciMLOperators", "Setfield", "SparseArrays", "Static", "StaticArraysCore", "Statistics", "Tricks", "TruncatedStacktraces"] +git-tree-sha1 = "531c53fd0405716712a8b4960216c3b7b5ec89b9" +uuid = "2b5f629d-d688-5b77-993f-72d75c75574e" +version = "6.149.1" + + [deps.DiffEqBase.extensions] + DiffEqBaseChainRulesCoreExt = "ChainRulesCore" + DiffEqBaseDistributionsExt = "Distributions" + DiffEqBaseEnzymeExt = ["ChainRulesCore", "Enzyme"] + DiffEqBaseGeneralizedGeneratedExt = "GeneralizedGenerated" + DiffEqBaseMPIExt = "MPI" + DiffEqBaseMeasurementsExt = "Measurements" + DiffEqBaseMonteCarloMeasurementsExt = "MonteCarloMeasurements" + DiffEqBaseReverseDiffExt = "ReverseDiff" + DiffEqBaseTrackerExt = "Tracker" + DiffEqBaseUnitfulExt = "Unitful" + + [deps.DiffEqBase.weakdeps] + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" + Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" + GeneralizedGenerated = "6b9d7cbe-bcb9-11e9-073f-15a7a543e2eb" + MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" + Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7" + MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca" + ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" + Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" + Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" + +[[deps.DiffEqCallbacks]] +deps = ["DataStructures", "DiffEqBase", "ForwardDiff", "Functors", "LinearAlgebra", "Markdown", "NLsolve", "Parameters", "RecipesBase", "RecursiveArrayTools", "SciMLBase", "StaticArraysCore"] +git-tree-sha1 = "ee954c8b9d348b7a8a6aec5f28288bf5adecd4ee" +uuid = "459566f4-90b8-5000-8ac3-15dfb0a30def" +version = "2.37.0" + + [deps.DiffEqCallbacks.weakdeps] + OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" + Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4" + +[[deps.DiffResults]] +deps = ["StaticArraysCore"] +git-tree-sha1 = "782dd5f4561f5d267313f23853baaaa4c52ea621" +uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" +version = "1.1.0" + +[[deps.DiffRules]] +deps = ["IrrationalConstants", "LogExpFunctions", "NaNMath", "Random", "SpecialFunctions"] +git-tree-sha1 = "23163d55f885173722d1e4cf0f6110cdbaf7e272" +uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" +version = "1.15.1" + +[[deps.Distances]] +deps = ["LinearAlgebra", "Statistics", "StatsAPI"] +git-tree-sha1 = "66c4c81f259586e8f002eacebc177e1fb06363b0" +uuid = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" +version = "0.10.11" +weakdeps = ["ChainRulesCore", "SparseArrays"] + + [deps.Distances.extensions] + DistancesChainRulesCoreExt = "ChainRulesCore" + DistancesSparseArraysExt = "SparseArrays" + +[[deps.Distributed]] +deps = ["Random", "Serialization", "Sockets"] +uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" + +[[deps.DocStringExtensions]] +deps = ["LibGit2"] +git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d" +uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +version = "0.9.3" + +[[deps.Downloads]] +deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] +uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" +version = "1.6.0" + +[[deps.EnumX]] +git-tree-sha1 = "bdb1942cd4c45e3c678fd11569d5cccd80976237" +uuid = "4e289a0a-7415-4d19-859d-a7e5c4648b56" +version = "1.0.4" + +[[deps.EnzymeCore]] +git-tree-sha1 = "18394bc78ac2814ff38fe5e0c9dc2cd171e2810c" +uuid = "f151be2c-9106-41f4-ab19-57ee4f262869" +version = "0.7.2" +weakdeps = ["Adapt"] + + [deps.EnzymeCore.extensions] + AdaptExt = "Adapt" + +[[deps.ExprTools]] +git-tree-sha1 = "27415f162e6028e81c72b82ef756bf321213b6ec" +uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04" +version = "0.1.10" + +[[deps.FastBroadcast]] +deps = ["ArrayInterface", "LinearAlgebra", "Polyester", "Static", "StaticArrayInterface", "StrideArraysCore"] +git-tree-sha1 = "a6e756a880fc419c8b41592010aebe6a5ce09136" +uuid = "7034ab61-46d4-4ed7-9d0f-46aef9175898" +version = "0.2.8" + +[[deps.FastClosures]] +git-tree-sha1 = "acebe244d53ee1b461970f8910c235b259e772ef" +uuid = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" +version = "0.3.2" + +[[deps.FileWatching]] +uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" + +[[deps.FiniteDiff]] +deps = ["ArrayInterface", "LinearAlgebra", "Requires", "Setfield", "SparseArrays"] +git-tree-sha1 = "2de436b72c3422940cbe1367611d137008af7ec3" +uuid = "6a86dc24-6348-571c-b903-95158fe2bd41" +version = "2.23.1" + + [deps.FiniteDiff.extensions] + FiniteDiffBandedMatricesExt = "BandedMatrices" + FiniteDiffBlockBandedMatricesExt = "BlockBandedMatrices" + FiniteDiffStaticArraysExt = "StaticArrays" + + [deps.FiniteDiff.weakdeps] + BandedMatrices = "aae01518-5342-5314-be14-df237901396f" + BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0" + StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + +[[deps.ForwardDiff]] +deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions"] +git-tree-sha1 = "cf0fe81336da9fb90944683b8c41984b08793dad" +uuid = "f6369f11-7733-5829-9624-2563aa707210" +version = "0.10.36" +weakdeps = ["StaticArrays"] + + [deps.ForwardDiff.extensions] + ForwardDiffStaticArraysExt = "StaticArrays" + +[[deps.FunctionWrappers]] +git-tree-sha1 = "d62485945ce5ae9c0c48f124a84998d755bae00e" +uuid = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e" +version = "1.1.3" + +[[deps.FunctionWrappersWrappers]] +deps = ["FunctionWrappers"] +git-tree-sha1 = "b104d487b34566608f8b4e1c39fb0b10aa279ff8" +uuid = "77dc65aa-8811-40c2-897b-53d922fa7daf" +version = "0.1.3" + +[[deps.Functors]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "d3e63d9fa13f8eaa2f06f64949e2afc593ff52c2" +uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +version = "0.4.10" + +[[deps.Future]] +deps = ["Random"] +uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" + +[[deps.GPUArraysCore]] +deps = ["Adapt"] +git-tree-sha1 = "ec632f177c0d990e64d955ccc1b8c04c485a0950" +uuid = "46192b85-c4d5-4398-a991-12ede77f4527" +version = "0.1.6" + +[[deps.IfElse]] +git-tree-sha1 = "debdd00ffef04665ccbb3e150747a77560e8fad1" +uuid = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" +version = "0.1.1" + +[[deps.InteractiveUtils]] +deps = ["Markdown"] +uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" + +[[deps.InverseFunctions]] +deps = ["Test"] +git-tree-sha1 = "896385798a8d49a255c398bd49162062e4a4c435" +uuid = "3587e190-3f89-42d0-90ee-14403ec27112" +version = "0.1.13" +weakdeps = ["Dates"] + + [deps.InverseFunctions.extensions] + DatesExt = "Dates" + +[[deps.IrrationalConstants]] +git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2" +uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" +version = "0.2.2" + +[[deps.IteratorInterfaceExtensions]] +git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" +uuid = "82899510-4779-5014-852e-03e436cf321d" +version = "1.0.0" + +[[deps.JLLWrappers]] +deps = ["Artifacts", "Preferences"] +git-tree-sha1 = "7e5d6779a1e09a36db2a7b6cff50942a0a7d0fca" +uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" +version = "1.5.0" + +[[deps.KernelAbstractions]] +deps = ["Adapt", "Atomix", "InteractiveUtils", "LinearAlgebra", "MacroTools", "PrecompileTools", "Requires", "SparseArrays", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"] +git-tree-sha1 = "ed7167240f40e62d97c1f5f7735dea6de3cc5c49" +uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c" +version = "0.9.18" +weakdeps = ["EnzymeCore"] + + [deps.KernelAbstractions.extensions] + EnzymeExt = "EnzymeCore" + +[[deps.LLVM]] +deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Preferences", "Printf", "Requires", "Unicode"] +git-tree-sha1 = "839c82932db86740ae729779e610f07a1640be9a" +uuid = "929cbde3-209d-540e-8aea-75f648917ca0" +version = "6.6.3" + + [deps.LLVM.extensions] + BFloat16sExt = "BFloat16s" + + [deps.LLVM.weakdeps] + BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" + +[[deps.LLVMExtra_jll]] +deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] +git-tree-sha1 = "88b916503aac4fb7f701bb625cd84ca5dd1677bc" +uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab" +version = "0.0.29+0" + +[[deps.LayoutPointers]] +deps = ["ArrayInterface", "LinearAlgebra", "ManualMemory", "SIMDTypes", "Static", "StaticArrayInterface"] +git-tree-sha1 = "62edfee3211981241b57ff1cedf4d74d79519277" +uuid = "10f19ff3-798f-405d-979b-55457f8fc047" +version = "0.1.15" + +[[deps.LazyArtifacts]] +deps = ["Artifacts", "Pkg"] +uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3" + +[[deps.LibCURL]] +deps = ["LibCURL_jll", "MozillaCACerts_jll"] +uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" +version = "0.6.4" + +[[deps.LibCURL_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] +uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" +version = "8.4.0+0" + +[[deps.LibGit2]] +deps = ["Base64", "LibGit2_jll", "NetworkOptions", "Printf", "SHA"] +uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" + +[[deps.LibGit2_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll"] +uuid = "e37daf67-58a4-590a-8e99-b0245dd2ffc5" +version = "1.6.4+0" + +[[deps.LibSSH2_jll]] +deps = ["Artifacts", "Libdl", "MbedTLS_jll"] +uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" +version = "1.11.0+1" + +[[deps.Libdl]] +uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" + +[[deps.LineSearches]] +deps = ["LinearAlgebra", "NLSolversBase", "NaNMath", "Parameters", "Printf"] +git-tree-sha1 = "7bbea35cec17305fc70a0e5b4641477dc0789d9d" +uuid = "d3d80556-e9d4-5f37-9878-2ab0fcc64255" +version = "7.2.0" + +[[deps.LinearAlgebra]] +deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"] +uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + +[[deps.LogExpFunctions]] +deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"] +git-tree-sha1 = "18144f3e9cbe9b15b070288eef858f71b291ce37" +uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +version = "0.3.27" + + [deps.LogExpFunctions.extensions] + LogExpFunctionsChainRulesCoreExt = "ChainRulesCore" + LogExpFunctionsChangesOfVariablesExt = "ChangesOfVariables" + LogExpFunctionsInverseFunctionsExt = "InverseFunctions" + + [deps.LogExpFunctions.weakdeps] + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" + InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" + +[[deps.Logging]] +uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" + +[[deps.Lux]] +deps = ["ADTypes", "Adapt", "ArrayInterface", "ChainRulesCore", "ConcreteStructs", "ConstructionBase", "FastClosures", "Functors", "GPUArraysCore", "LinearAlgebra", "LuxCore", "LuxDeviceUtils", "LuxLib", "MacroTools", "Markdown", "PrecompileTools", "Preferences", "Random", "Reexport", "Setfield", "Statistics", "WeightInitializers"] +git-tree-sha1 = "d7f49df9abfbb372fcbde5f41e547aa3679e9793" +repo-rev = "ap/nested_ad" +repo-url = "https://github.com/LuxDL/Lux.jl.git" +uuid = "b2108857-7c20-44ae-9111-449ecde12c47" +version = "0.5.38" + + [deps.Lux.extensions] + LuxComponentArraysExt = "ComponentArrays" + LuxComponentArraysReverseDiffExt = ["ComponentArrays", "ReverseDiff"] + LuxDynamicExpressionsExt = "DynamicExpressions" + LuxDynamicExpressionsForwardDiffExt = ["DynamicExpressions", "ForwardDiff"] + LuxFluxExt = "Flux" + LuxForwardDiffExt = "ForwardDiff" + LuxLuxAMDGPUExt = "LuxAMDGPU" + LuxMLUtilsExt = "MLUtils" + LuxMPIExt = "MPI" + LuxMPINCCLExt = ["CUDA", "MPI", "NCCL"] + LuxOptimisersExt = "Optimisers" + LuxReverseDiffExt = "ReverseDiff" + LuxSimpleChainsExt = "SimpleChains" + LuxTrackerExt = "Tracker" + LuxZygoteExt = "Zygote" + + [deps.Lux.weakdeps] + CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" + DynamicExpressions = "a40a106e-89c9-4ca8-8020-a735e8728b6b" + Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" + ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" + LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b" + MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" + MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" + NCCL = "3fe64909-d7a1-4096-9b7d-7a0f12cf0f6b" + Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" + ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" + SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5" + Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" + Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + +[[deps.LuxCore]] +deps = ["FastClosures", "Functors", "Random", "Setfield"] +git-tree-sha1 = "f799f3aa8599f79ed5e2c9fbaf74907c1ebe15ce" +uuid = "bb33d45b-7691-41d6-9220-0943567d0623" +version = "0.1.14" + +[[deps.LuxDeviceUtils]] +deps = ["Adapt", "ChainRulesCore", "FastClosures", "Functors", "LuxCore", "PrecompileTools", "Preferences", "Random"] +git-tree-sha1 = "bbcf12d598b8ef6d2b12e506b1d18125552c3b27" +uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553" +version = "0.1.20" + + [deps.LuxDeviceUtils.extensions] + LuxDeviceUtilsAMDGPUExt = "AMDGPU" + LuxDeviceUtilsCUDAExt = "CUDA" + LuxDeviceUtilsFillArraysExt = "FillArrays" + LuxDeviceUtilsGPUArraysExt = "GPUArrays" + LuxDeviceUtilsLuxAMDGPUExt = "LuxAMDGPU" + LuxDeviceUtilsLuxCUDAExt = "LuxCUDA" + LuxDeviceUtilsMetalGPUArraysExt = ["GPUArrays", "Metal"] + LuxDeviceUtilsRecursiveArrayToolsExt = "RecursiveArrayTools" + LuxDeviceUtilsSparseArraysExt = "SparseArrays" + LuxDeviceUtilsZygoteExt = "Zygote" + + [deps.LuxDeviceUtils.weakdeps] + AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" + CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" + GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" + LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b" + LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" + Metal = "dde4c033-4e86-420c-a63e-0dd931031962" + RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" + SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + +[[deps.LuxLib]] +deps = ["ChainRulesCore", "FastClosures", "KernelAbstractions", "LuxCore", "Markdown", "NNlib", "PrecompileTools", "Random", "Reexport", "Statistics"] +git-tree-sha1 = "b1f81a8aa8313c1f1b4cbfb18733db17c023427e" +uuid = "82251201-b29d-42c6-8e01-566dec8acb11" +version = "0.3.14" + + [deps.LuxLib.extensions] + LuxLibForwardDiffExt = "ForwardDiff" + LuxLibReverseDiffExt = "ReverseDiff" + LuxLibTrackerAMDGPUExt = ["AMDGPU", "Tracker"] + LuxLibTrackerExt = "Tracker" + LuxLibTrackercuDNNExt = ["CUDA", "Tracker", "cuDNN"] + LuxLibcuDNNExt = ["CUDA", "cuDNN"] + + [deps.LuxLib.weakdeps] + AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" + CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" + ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" + Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" + cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" + +[[deps.MacroTools]] +deps = ["Markdown", "Random"] +git-tree-sha1 = "2fa9ee3e63fd3a4f7a9a4f4744a52f4856de82df" +uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +version = "0.5.13" + +[[deps.ManualMemory]] +git-tree-sha1 = "bcaef4fc7a0cfe2cba636d84cda54b5e4e4ca3cd" +uuid = "d125e4d3-2237-4719-b19c-fa641b8a4667" +version = "0.1.8" + +[[deps.Markdown]] +deps = ["Base64"] +uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" + +[[deps.MbedTLS_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" +version = "2.28.2+1" + +[[deps.MozillaCACerts_jll]] +uuid = "14a3606d-f60d-562e-9121-12d972cd8159" +version = "2023.1.10" + +[[deps.MuladdMacro]] +git-tree-sha1 = "cac9cc5499c25554cba55cd3c30543cff5ca4fab" +uuid = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221" +version = "0.2.4" + +[[deps.NLSolversBase]] +deps = ["DiffResults", "Distributed", "FiniteDiff", "ForwardDiff"] +git-tree-sha1 = "a0b464d183da839699f4c79e7606d9d186ec172c" +uuid = "d41bc354-129a-5804-8e4c-c37616107c6c" +version = "7.8.3" + +[[deps.NLsolve]] +deps = ["Distances", "LineSearches", "LinearAlgebra", "NLSolversBase", "Printf", "Reexport"] +git-tree-sha1 = "019f12e9a1a7880459d0173c182e6a99365d7ac1" +uuid = "2774e3e8-f4cf-5e23-947b-6d7e65073b56" +version = "4.5.1" + +[[deps.NNlib]] +deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"] +git-tree-sha1 = "5055845dd316575ae2fc1f6dcb3545ff15fe547a" +uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +version = "0.9.14" + + [deps.NNlib.extensions] + NNlibAMDGPUExt = "AMDGPU" + NNlibCUDACUDNNExt = ["CUDA", "cuDNN"] + NNlibCUDAExt = "CUDA" + NNlibEnzymeCoreExt = "EnzymeCore" + + [deps.NNlib.weakdeps] + AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" + CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" + cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" + +[[deps.NaNMath]] +deps = ["OpenLibm_jll"] +git-tree-sha1 = "0877504529a3e5c3343c6f8b4c0381e57e4387e4" +uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" +version = "1.0.2" + +[[deps.NetworkOptions]] +uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" +version = "1.2.0" + +[[deps.OpenBLAS_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] +uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" +version = "0.3.23+4" + +[[deps.OpenLibm_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "05823500-19ac-5b8b-9628-191a04bc5112" +version = "0.8.1+2" + +[[deps.OpenSpecFun_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1" +uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" +version = "0.5.5+0" + +[[deps.OrderedCollections]] +git-tree-sha1 = "dfdf5519f235516220579f949664f1bf44e741c5" +uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" +version = "1.6.3" + +[[deps.Parameters]] +deps = ["OrderedCollections", "UnPack"] +git-tree-sha1 = "34c0e9ad262e5f7fc75b10a9952ca7692cfc5fbe" +uuid = "d96e819e-fc66-5662-9728-84c9c7592b0a" +version = "0.12.3" + +[[deps.PartialFunctions]] +deps = ["MacroTools"] +git-tree-sha1 = "47b49a4dbc23b76682205c646252c0f9e1eb75af" +uuid = "570af359-4316-4cb7-8c74-252c00c2016b" +version = "1.2.0" + +[[deps.Pkg]] +deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] +uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +version = "1.10.0" + +[[deps.Polyester]] +deps = ["ArrayInterface", "BitTwiddlingConvenienceFunctions", "CPUSummary", "IfElse", "ManualMemory", "PolyesterWeave", "Requires", "Static", "StaticArrayInterface", "StrideArraysCore", "ThreadingUtilities"] +git-tree-sha1 = "2ba5f33cbb51a85ef58a850749492b08f9bf2193" +uuid = "f517fe37-dbe3-4b94-8317-1923a5111588" +version = "0.7.13" + +[[deps.PolyesterWeave]] +deps = ["BitTwiddlingConvenienceFunctions", "CPUSummary", "IfElse", "Static", "ThreadingUtilities"] +git-tree-sha1 = "240d7170f5ffdb285f9427b92333c3463bf65bf6" +uuid = "1d0040c9-8b98-4ee7-8388-3f51789ca0ad" +version = "0.2.1" + +[[deps.PreallocationTools]] +deps = ["Adapt", "ArrayInterface", "ForwardDiff"] +git-tree-sha1 = "a660e9daab5db07adf3dedfe09b435cc530d855e" +uuid = "d236fae5-4411-538c-8e31-a6e3d9e00b46" +version = "0.4.21" + + [deps.PreallocationTools.extensions] + PreallocationToolsReverseDiffExt = "ReverseDiff" + + [deps.PreallocationTools.weakdeps] + ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" + +[[deps.PrecompileTools]] +deps = ["Preferences"] +git-tree-sha1 = "5aa36f7049a63a1528fe8f7c3f2113413ffd4e1f" +uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a" +version = "1.2.1" + +[[deps.Preferences]] +deps = ["TOML"] +git-tree-sha1 = "9306f6085165d270f7e3db02af26a400d580f5c6" +uuid = "21216c6a-2e73-6563-6e65-726566657250" +version = "1.4.3" + +[[deps.Printf]] +deps = ["Unicode"] +uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" + +[[deps.REPL]] +deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] +uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" + +[[deps.Random]] +deps = ["SHA"] +uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" + +[[deps.RecipesBase]] +deps = ["PrecompileTools"] +git-tree-sha1 = "5c3d09cc4f31f5fc6af001c250bf1278733100ff" +uuid = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" +version = "1.3.4" + +[[deps.RecursiveArrayTools]] +deps = ["Adapt", "ArrayInterface", "DocStringExtensions", "GPUArraysCore", "IteratorInterfaceExtensions", "LinearAlgebra", "RecipesBase", "SparseArrays", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface", "Tables"] +git-tree-sha1 = "d8f131090f2e44b145084928856a561c83f43b27" +uuid = "731186ca-8d62-57ce-b412-fbd966d074cd" +version = "3.13.0" + + [deps.RecursiveArrayTools.extensions] + RecursiveArrayToolsFastBroadcastExt = "FastBroadcast" + RecursiveArrayToolsForwardDiffExt = "ForwardDiff" + RecursiveArrayToolsMeasurementsExt = "Measurements" + RecursiveArrayToolsMonteCarloMeasurementsExt = "MonteCarloMeasurements" + RecursiveArrayToolsReverseDiffExt = ["ReverseDiff", "Zygote"] + RecursiveArrayToolsTrackerExt = "Tracker" + RecursiveArrayToolsZygoteExt = "Zygote" + + [deps.RecursiveArrayTools.weakdeps] + FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898" + ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" + Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7" + MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca" + ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" + Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" + Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + +[[deps.Reexport]] +git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b" +uuid = "189a3867-3050-52da-a836-e630ba90ab69" +version = "1.2.2" + +[[deps.Requires]] +deps = ["UUIDs"] +git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7" +uuid = "ae029012-a4dd-5104-9daa-d747884805df" +version = "1.3.0" + +[[deps.RuntimeGeneratedFunctions]] +deps = ["ExprTools", "SHA", "Serialization"] +git-tree-sha1 = "04c968137612c4a5629fa531334bb81ad5680f00" +uuid = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47" +version = "0.5.13" + +[[deps.SHA]] +uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" +version = "0.7.0" + +[[deps.SIMDTypes]] +git-tree-sha1 = "330289636fb8107c5f32088d2741e9fd7a061a5c" +uuid = "94e857df-77ce-4151-89e5-788b33177be4" +version = "0.1.0" + +[[deps.SciMLBase]] +deps = ["ADTypes", "ArrayInterface", "CommonSolve", "ConstructionBase", "Distributed", "DocStringExtensions", "EnumX", "FunctionWrappersWrappers", "IteratorInterfaceExtensions", "LinearAlgebra", "Logging", "Markdown", "PrecompileTools", "Preferences", "Printf", "RecipesBase", "RecursiveArrayTools", "Reexport", "RuntimeGeneratedFunctions", "SciMLOperators", "SciMLStructures", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface", "Tables"] +git-tree-sha1 = "beb1f94b08c4976ed1db0ca01b9e6bac89706faf" +uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462" +version = "2.35.0" + + [deps.SciMLBase.extensions] + SciMLBaseChainRulesCoreExt = "ChainRulesCore" + SciMLBaseMakieExt = "Makie" + SciMLBasePartialFunctionsExt = "PartialFunctions" + SciMLBasePyCallExt = "PyCall" + SciMLBasePythonCallExt = "PythonCall" + SciMLBaseRCallExt = "RCall" + SciMLBaseZygoteExt = "Zygote" + + [deps.SciMLBase.weakdeps] + ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a" + PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b" + PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" + PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d" + RCall = "6f49c342-dc21-5d91-9882-a32aef131414" + Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + +[[deps.SciMLOperators]] +deps = ["ArrayInterface", "DocStringExtensions", "LinearAlgebra", "MacroTools", "Setfield", "SparseArrays", "StaticArraysCore"] +git-tree-sha1 = "10499f619ef6e890f3f4a38914481cc868689cd5" +uuid = "c0aeaf25-5076-4817-a8d5-81caf7dfa961" +version = "0.3.8" + +[[deps.SciMLStructures]] +git-tree-sha1 = "5833c10ce83d690c124beedfe5f621b50b02ba4d" +uuid = "53ae85a6-f571-4167-b2af-e1d143709226" +version = "1.1.0" + +[[deps.Serialization]] +uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" + +[[deps.Setfield]] +deps = ["ConstructionBase", "Future", "MacroTools", "StaticArraysCore"] +git-tree-sha1 = "e2cc6d8c88613c05e1defb55170bf5ff211fbeac" +uuid = "efcf1570-3423-57d1-acb7-fd33fddbac46" +version = "1.1.1" + +[[deps.Sockets]] +uuid = "6462fe0b-24de-5631-8697-dd941f90decc" + +[[deps.SparseArrays]] +deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] +uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +version = "1.10.0" + +[[deps.SpecialFunctions]] +deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] +git-tree-sha1 = "e2cfc4012a19088254b3950b85c3c1d8882d864d" +uuid = "276daf66-3868-5448-9aa4-cd146d93841b" +version = "2.3.1" +weakdeps = ["ChainRulesCore"] + + [deps.SpecialFunctions.extensions] + SpecialFunctionsChainRulesCoreExt = "ChainRulesCore" + +[[deps.Static]] +deps = ["IfElse"] +git-tree-sha1 = "d2fdac9ff3906e27f7a618d47b676941baa6c80c" +uuid = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" +version = "0.8.10" + +[[deps.StaticArrayInterface]] +deps = ["ArrayInterface", "Compat", "IfElse", "LinearAlgebra", "PrecompileTools", "Requires", "SparseArrays", "Static", "SuiteSparse"] +git-tree-sha1 = "5d66818a39bb04bf328e92bc933ec5b4ee88e436" +uuid = "0d7ed370-da01-4f52-bd93-41d350b8b718" +version = "1.5.0" + + [deps.StaticArrayInterface.extensions] + StaticArrayInterfaceOffsetArraysExt = "OffsetArrays" + StaticArrayInterfaceStaticArraysExt = "StaticArrays" + + [deps.StaticArrayInterface.weakdeps] + OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" + StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + +[[deps.StaticArrays]] +deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"] +git-tree-sha1 = "bf074c045d3d5ffd956fa0a461da38a44685d6b2" +uuid = "90137ffa-7385-5640-81b9-e52037218182" +version = "1.9.3" +weakdeps = ["ChainRulesCore", "Statistics"] + + [deps.StaticArrays.extensions] + StaticArraysChainRulesCoreExt = "ChainRulesCore" + StaticArraysStatisticsExt = "Statistics" + +[[deps.StaticArraysCore]] +git-tree-sha1 = "36b3d696ce6366023a0ea192b4cd442268995a0d" +uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" +version = "1.4.2" + +[[deps.Statistics]] +deps = ["LinearAlgebra", "SparseArrays"] +uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +version = "1.10.0" + +[[deps.StatsAPI]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "1ff449ad350c9c4cbc756624d6f8a8c3ef56d3ed" +uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" +version = "1.7.0" + +[[deps.SteadyStateDiffEq]] +deps = ["ConcreteStructs", "DiffEqBase", "DiffEqCallbacks", "LinearAlgebra", "Reexport", "SciMLBase"] +git-tree-sha1 = "a735fd5053724cf4de31c81b4e2cc429db844be5" +uuid = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f" +version = "2.0.1" + +[[deps.StrideArraysCore]] +deps = ["ArrayInterface", "CloseOpenIntervals", "IfElse", "LayoutPointers", "LinearAlgebra", "ManualMemory", "SIMDTypes", "Static", "StaticArrayInterface", "ThreadingUtilities"] +git-tree-sha1 = "25349bf8f63aa36acbff5e3550a86e9f5b0ef682" +uuid = "7792a7ef-975c-4747-a70f-980b88e8d1da" +version = "0.5.6" + +[[deps.SuiteSparse]] +deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] +uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" + +[[deps.SuiteSparse_jll]] +deps = ["Artifacts", "Libdl", "libblastrampoline_jll"] +uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" +version = "7.2.1+1" + +[[deps.SymbolicIndexingInterface]] +deps = ["Accessors", "ArrayInterface", "MacroTools", "RuntimeGeneratedFunctions", "StaticArraysCore"] +git-tree-sha1 = "40ea524431a92328cd73582d1820a5b08247a40f" +uuid = "2efcf032-c050-4f8e-a9bb-153293bab1f5" +version = "0.3.16" + +[[deps.TOML]] +deps = ["Dates"] +uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" +version = "1.0.3" + +[[deps.TableTraits]] +deps = ["IteratorInterfaceExtensions"] +git-tree-sha1 = "c06b2f539df1c6efa794486abfb6ed2022561a39" +uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c" +version = "1.0.1" + +[[deps.Tables]] +deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "LinearAlgebra", "OrderedCollections", "TableTraits"] +git-tree-sha1 = "cb76cf677714c095e535e3501ac7954732aeea2d" +uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" +version = "1.11.1" + +[[deps.Tar]] +deps = ["ArgTools", "SHA"] +uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" +version = "1.10.0" + +[[deps.Test]] +deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] +uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[[deps.ThreadingUtilities]] +deps = ["ManualMemory"] +git-tree-sha1 = "eda08f7e9818eb53661b3deb74e3159460dfbc27" +uuid = "8290d209-cae3-49c0-8002-c8c24d57dab5" +version = "0.5.2" + +[[deps.Tricks]] +git-tree-sha1 = "eae1bb484cd63b36999ee58be2de6c178105112f" +uuid = "410a4b4d-49e4-4fbc-ab6d-cb71b17b3775" +version = "0.1.8" + +[[deps.TruncatedStacktraces]] +deps = ["InteractiveUtils", "MacroTools", "Preferences"] +git-tree-sha1 = "ea3e54c2bdde39062abf5a9758a23735558705e1" +uuid = "781d530d-4396-4725-bb49-402e4bee1e77" +version = "1.4.0" + +[[deps.UUIDs]] +deps = ["Random", "SHA"] +uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" + +[[deps.UnPack]] +git-tree-sha1 = "387c1f73762231e86e0c9c5443ce3b4a0a9a0c2b" +uuid = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" +version = "1.0.2" + +[[deps.Unicode]] +uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" + +[[deps.UnsafeAtomics]] +git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278" +uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f" +version = "0.2.1" + +[[deps.UnsafeAtomicsLLVM]] +deps = ["LLVM", "UnsafeAtomics"] +git-tree-sha1 = "323e3d0acf5e78a56dfae7bd8928c989b4f3083e" +uuid = "d80eeb9a-aca5-4d75-85e5-170c8b632249" +version = "0.1.3" + +[[deps.WeightInitializers]] +deps = ["ChainRulesCore", "LinearAlgebra", "PartialFunctions", "PrecompileTools", "Random", "SpecialFunctions", "Statistics"] +git-tree-sha1 = "f0e6760ef9d22f043710289ddf29e4a4048c4822" +uuid = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" +version = "0.1.7" + + [deps.WeightInitializers.extensions] + WeightInitializersCUDAExt = "CUDA" + + [deps.WeightInitializers.weakdeps] + CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + +[[deps.Zlib_jll]] +deps = ["Libdl"] +uuid = "83775a58-1f1d-513f-b197-d71354ab007a" +version = "1.2.13+1" + +[[deps.libblastrampoline_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" +version = "5.8.0+1" + +[[deps.nghttp2_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" +version = "1.52.0+1" + +[[deps.p7zip_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" +version = "17.4.0+2" From bb61c5f1aa24087a3ae979490c1ccc444c018d2e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 24 Apr 2024 15:20:54 -0400 Subject: [PATCH 3/6] Faster Nested AD --- Manifest.toml | 38 +++++++++++++++++-- Project.toml | 4 +- ext/DeepEquilibriumNetworksZygoteExt.jl | 49 ++++++++++++++++++++++--- 3 files changed, 81 insertions(+), 10 deletions(-) diff --git a/Manifest.toml b/Manifest.toml index 142f0567..c4d00d25 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -497,7 +497,7 @@ uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" [[deps.Lux]] deps = ["ADTypes", "Adapt", "ArrayInterface", "ChainRulesCore", "ConcreteStructs", "ConstructionBase", "FastClosures", "Functors", "GPUArraysCore", "LinearAlgebra", "LuxCore", "LuxDeviceUtils", "LuxLib", "MacroTools", "Markdown", "PrecompileTools", "Preferences", "Random", "Reexport", "Setfield", "Statistics", "WeightInitializers"] -git-tree-sha1 = "d7f49df9abfbb372fcbde5f41e547aa3679e9793" +git-tree-sha1 = "295c76513705518749fd4e151d9de77c75049d43" repo-rev = "ap/nested_ad" repo-url = "https://github.com/LuxDL/Lux.jl.git" uuid = "b2108857-7c20-44ae-9111-449ecde12c47" @@ -573,12 +573,13 @@ version = "0.1.20" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [[deps.LuxLib]] -deps = ["ChainRulesCore", "FastClosures", "KernelAbstractions", "LuxCore", "Markdown", "NNlib", "PrecompileTools", "Random", "Reexport", "Statistics"] -git-tree-sha1 = "b1f81a8aa8313c1f1b4cbfb18733db17c023427e" +deps = ["ArrayInterface", "ChainRulesCore", "FastBroadcast", "FastClosures", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "LuxCore", "Markdown", "NNlib", "PrecompileTools", "Random", "Reexport", "Statistics", "Strided"] +git-tree-sha1 = "7cb3cdf01835d508f2c81e09d2e93f309434b5d6" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" -version = "0.3.14" +version = "0.3.15" [deps.LuxLib.extensions] + LuxLibAMDGPUExt = "AMDGPU" LuxLibForwardDiffExt = "ForwardDiff" LuxLibReverseDiffExt = "ReverseDiff" LuxLibTrackerAMDGPUExt = ["AMDGPU", "Tracker"] @@ -684,6 +685,12 @@ git-tree-sha1 = "dfdf5519f235516220579f949664f1bf44e741c5" uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" version = "1.6.3" +[[deps.PackageExtensionCompat]] +git-tree-sha1 = "fb28e33b8a95c4cee25ce296c817d89cc2e53518" +uuid = "65ce6f38-6b18-4e1d-a461-8949797d7930" +version = "1.0.2" +weakdeps = ["Requires", "TOML"] + [[deps.Parameters]] deps = ["OrderedCollections", "UnPack"] git-tree-sha1 = "34c0e9ad262e5f7fc75b10a9952ca7692cfc5fbe" @@ -927,6 +934,24 @@ git-tree-sha1 = "25349bf8f63aa36acbff5e3550a86e9f5b0ef682" uuid = "7792a7ef-975c-4747-a70f-980b88e8d1da" version = "0.5.6" +[[deps.Strided]] +deps = ["LinearAlgebra", "StridedViews", "TupleTools"] +git-tree-sha1 = "40c69be0e1b72ee2f42923b7d1ff13e0b04e675c" +uuid = "5e0ebb24-38b0-5f93-81fe-25c709ecae67" +version = "2.0.4" + +[[deps.StridedViews]] +deps = ["LinearAlgebra", "PackageExtensionCompat"] +git-tree-sha1 = "5b765c4e401693ab08981989f74a36a010aa1d8e" +uuid = "4db3bf67-4bd7-4b4e-b153-31dc3fb37143" +version = "0.2.2" + + [deps.StridedViews.extensions] + StridedViewsCUDAExt = "CUDA" + + [deps.StridedViews.weakdeps] + CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + [[deps.SuiteSparse]] deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" @@ -985,6 +1010,11 @@ git-tree-sha1 = "ea3e54c2bdde39062abf5a9758a23735558705e1" uuid = "781d530d-4396-4725-bb49-402e4bee1e77" version = "1.4.0" +[[deps.TupleTools]] +git-tree-sha1 = "41d61b1c545b06279871ef1a4b5fcb2cac2191cd" +uuid = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6" +version = "1.5.0" + [[deps.UUIDs]] deps = ["Random", "SHA"] uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" diff --git a/Project.toml b/Project.toml index 0436322f..d3bab847 100644 --- a/Project.toml +++ b/Project.toml @@ -20,13 +20,14 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f" [weakdeps] +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] DeepEquilibriumNetworksSciMLSensitivityExt = ["LinearSolve", "SciMLSensitivity"] -DeepEquilibriumNetworksZygoteExt = "Zygote" +DeepEquilibriumNetworksZygoteExt = ["ForwardDiff", "Zygote"] [compat] ADTypes = "0.2.5, 1" @@ -38,6 +39,7 @@ ConstructionBase = "1" DiffEqBase = "6.119" ExplicitImports = "1.4.1" FastClosures = "0.3" +ForwardDiff = "0.10.36" Functors = "0.4.10" LinearSolve = "2.21.2" Lux = "0.5.37" diff --git a/ext/DeepEquilibriumNetworksZygoteExt.jl b/ext/DeepEquilibriumNetworksZygoteExt.jl index 7b848443..688bd2ca 100644 --- a/ext/DeepEquilibriumNetworksZygoteExt.jl +++ b/ext/DeepEquilibriumNetworksZygoteExt.jl @@ -1,19 +1,58 @@ module DeepEquilibriumNetworksZygoteExt using ADTypes: AutoZygote +using ChainRulesCore: ChainRulesCore +using DeepEquilibriumNetworks: DEQs using FastClosures: @closure +using ForwardDiff: ForwardDiff # This is a dependency of Zygote +using Lux: Lux, StatefulLuxLayer using Statistics: mean using Zygote: Zygote -using DeepEquilibriumNetworks: DEQs -@inline __tupleify(u) = @closure x -> (u, x) +const CRC = ChainRulesCore + +@inline __tupleify(x) = @closure(u->(u, x)) + +## One day we will overload DI's APIs for Lux Layers and we can remove this +## Main challenge with overloading Zygote.pullback is that we need to return the correct +## tangent for the pullback to compute the correct gradient, which is quite hard. But +## wrapping the overall vjp is not that hard. +@inline function __compute_vector_jacobian_product(model::StatefulLuxLayer, ps, z, x, rng) + res, back = Zygote.pullback(model ∘ __tupleify(x), z) + return only(back(DEQs.__gaussian_like(rng, res))) +end + +function CRC.rrule( + ::typeof(__compute_vector_jacobian_product), model::StatefulLuxLayer, ps, z, x, rng) + res, back = Zygote.pullback(model ∘ __tupleify(x), z) + ε = DEQs.__gaussian_like(rng, res) + y = only(back(ε)) + ∇internal_gradient_capture = Δ -> begin + (Δ isa CRC.NoTangent || Δ isa CRC.ZeroTangent) && + return ntuple(Returns(CRC.NoTangent()), 6) + + Δ_ = reshape(CRC.unthunk(Δ), size(z)) + + Tag = typeof(ForwardDiff.Tag(model, eltype(z))) + partials = ForwardDiff.Partials{1, eltype(z)}.(tuple.(Δ_)) + z_dual = ForwardDiff.Dual{Tag, eltype(z), 1}.(z, partials) + + _, pb_f = Zygote.pullback((x1, x2, p) -> model((x1, x2), p), z_dual, x, ps) + ∂z_duals, ∂x_duals, ∂ps_duals = pb_f(ε) + + ∂z = Lux.__partials(Tag, ∂z_duals, 1) + ∂x = Lux.__partials(Tag, ∂x_duals, 1) + ∂ps = Lux.__partials(Tag, ∂ps_duals, 1) + + return CRC.NoTangent(), CRC.NoTangent(), ∂ps, ∂z, ∂x, CRC.NoTangent() + end + return y, ∇internal_gradient_capture +end ## Don't remove `ad`. See https://github.com/ericphanson/ExplicitImports.jl/issues/33 ## FIXME: This will be broken in the new Lux release let's fix this function DEQs.__estimate_jacobian_trace(ad::AutoZygote, model, z, x, rng) - res, back = Zygote.pullback(model ∘ __tupleify, z) - vjp_z = only(back(DEQs.__gaussian_like(rng, res))) - return mean(abs2, vjp_z) + return mean(abs2, __compute_vector_jacobian_product(model, model.ps, z, x, rng)) end end From 908c224ad12754edaafc86ac463432ca3b30f5a6 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 24 Apr 2024 16:40:52 -0400 Subject: [PATCH 4/6] Update the documentation --- Manifest.toml | 6 +-- docs/Project.toml | 4 +- docs/src/tutorials/basic_mnist_deq.md | 59 +++++++++------------------ docs/src/tutorials/reduced_dim_deq.md | 54 +++++++++--------------- src/layers.jl | 3 +- 5 files changed, 44 insertions(+), 82 deletions(-) diff --git a/Manifest.toml b/Manifest.toml index c4d00d25..ca0bc6d1 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -497,7 +497,7 @@ uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" [[deps.Lux]] deps = ["ADTypes", "Adapt", "ArrayInterface", "ChainRulesCore", "ConcreteStructs", "ConstructionBase", "FastClosures", "Functors", "GPUArraysCore", "LinearAlgebra", "LuxCore", "LuxDeviceUtils", "LuxLib", "MacroTools", "Markdown", "PrecompileTools", "Preferences", "Random", "Reexport", "Setfield", "Statistics", "WeightInitializers"] -git-tree-sha1 = "295c76513705518749fd4e151d9de77c75049d43" +git-tree-sha1 = "ae13ecbe29ee7432dfd477b233db43c462b6a4ff" repo-rev = "ap/nested_ad" repo-url = "https://github.com/LuxDL/Lux.jl.git" uuid = "b2108857-7c20-44ae-9111-449ecde12c47" @@ -574,9 +574,9 @@ version = "0.1.20" [[deps.LuxLib]] deps = ["ArrayInterface", "ChainRulesCore", "FastBroadcast", "FastClosures", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "LuxCore", "Markdown", "NNlib", "PrecompileTools", "Random", "Reexport", "Statistics", "Strided"] -git-tree-sha1 = "7cb3cdf01835d508f2c81e09d2e93f309434b5d6" +git-tree-sha1 = "edbf65f5ceb15ebbfad9d03c6a846d83b9a97baf" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" -version = "0.3.15" +version = "0.3.16" [deps.LuxLib.extensions] LuxLibAMDGPUExt = "AMDGPU" diff --git a/docs/Project.toml b/docs/Project.toml index 428e8bda..79874acd 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,9 +1,9 @@ [deps] +Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" DeepEquilibriumNetworks = "6748aba7-0e9b-415e-a410-ae3cc0ecb334" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244" LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" -LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" MLDataUtils = "cc2ba9b6-d476-5e6d-8eaf-a92d5412d41d" @@ -11,6 +11,7 @@ MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" +Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" @@ -21,7 +22,6 @@ DeepEquilibriumNetworks = "2" Documenter = "1" DocumenterCitations = "1" LinearSolve = "2" -LoggingExtras = "1" Lux = "0.5" LuxCUDA = "0.3" MLDataUtils = "0.5" diff --git a/docs/src/tutorials/basic_mnist_deq.md b/docs/src/tutorials/basic_mnist_deq.md index f473ae55..644f52fe 100644 --- a/docs/src/tutorials/basic_mnist_deq.md +++ b/docs/src/tutorials/basic_mnist_deq.md @@ -4,7 +4,7 @@ We will train a simple Deep Equilibrium Model on MNIST. First we load a few pack ```@example basic_mnist_deq using DeepEquilibriumNetworks, SciMLSensitivity, Lux, NonlinearSolve, OrdinaryDiffEq, - Statistics, Random, Optimisers, LuxCUDA, Zygote, LinearSolve, LoggingExtras + Statistics, Random, Optimisers, LuxCUDA, Zygote, LinearSolve, Dates, Printf using MLDatasets: MNIST using MLDataUtils: LabelEnc, convertlabel, stratifiedobs, batchview @@ -20,18 +20,6 @@ const cdev = cpu_device() const gdev = gpu_device() ``` -SciMLBase introduced a warning instead of depwarn which pollutes the output. We can suppress -it with the following logger - -```@example basic_mnist_deq -function remove_syms_warning(log_args) - return log_args.message != - "The use of keyword arguments `syms`, `paramsyms` and `indepsym` for `SciMLFunction`s is deprecated. Pass `sys = SymbolCache(syms, paramsyms, indepsym)` instead." -end - -filtered_logger = ActiveFilteredLogger(remove_syms_warning, global_logger()) -``` - We can now construct our dataloader. ```@example basic_mnist_deq @@ -94,12 +82,12 @@ function construct_model(solver; model_type::Symbol=:deq) x = randn(rng, Float32, 28, 28, 1, 128) y = onehot(rand(Random.default_rng(), 0:9, 128)) |> gdev - model_ = Lux.Experimental.StatefulLuxLayer(model, ps, st) - @info "warming up forward pass" + model_ = StatefulLuxLayer(model, ps, st) + @printf "[%s] warming up forward pass\n" string(now()) logitcrossentropy(model_, x, ps, y) - @info "warming up backward pass" + @printf "[%s] warming up backward pass\n" string(now()) Zygote.gradient(logitcrossentropy, model_, x, ps, y) - @info "warmup complete" + @printf "[%s] warmup complete\n" string(now()) return model, ps, st end @@ -121,7 +109,7 @@ classify(x) = argmax.(eachcol(x)) function accuracy(model, data, ps, st) total_correct, total = 0, 0 st = Lux.testmode(st) - model = Lux.Experimental.StatefulLuxLayer(model, ps, st) + model = StatefulLuxLayer(model, ps, st) for (x, y) in data target_class = classify(cdev(y)) predicted_class = classify(cdev(model(x))) @@ -134,48 +122,43 @@ end function train_model( solver, model_type; data_train=zip(x_train, y_train), data_test=zip(x_test, y_test)) model, ps, st = construct_model(solver; model_type) - model_st = Lux.Experimental.StatefulLuxLayer(model, nothing, st) + model_st = StatefulLuxLayer(model, nothing, st) - @info "Training Model: $(model_type) with Solver: $(nameof(typeof(solver)))" + @printf "[%s] Training Model: %s with Solver: %s\n" string(now()) model_type nameof(typeof(solver)) opt_st = Optimisers.setup(Adam(0.001), ps) acc = accuracy(model, data_test, ps, st) * 100 - @info "Starting Accuracy: $(acc)" + @printf "[%s] Starting Accuracy: %.5f%%\n" string(now()) acc - @info "Pretrain with unrolling to a depth of 5" + @printf "[%s] Pretrain with unrolling to a depth of 5\n" string(now()) st = Lux.update_state(st, :fixed_depth, Val(5)) - model_st = Lux.Experimental.StatefulLuxLayer(model, ps, st) + model_st = StatefulLuxLayer(model, ps, st) for (i, (x, y)) in enumerate(data_train) res = Zygote.withgradient(logitcrossentropy, model_st, x, ps, y) Optimisers.update!(opt_st, ps, res.grad[3]) - if i % 50 == 1 - @info "Pretraining Batch: [$(i)/$(length(data_train))] Loss: $(res.val)" - end + i % 50 == 1 && @printf "[%s] Pretraining Batch: [%4d/%4d] Loss: %.5f\n" string(now()) i length(data_train) res.val end acc = accuracy(model, data_test, ps, model_st.st) * 100 - @info "Pretraining complete. Accuracy: $(acc)" + @printf "[%s] Pretraining complete. Accuracy: %.5f%%\n" string(now()) acc st = Lux.update_state(st, :fixed_depth, Val(0)) - model_st = Lux.Experimental.StatefulLuxLayer(model, ps, st) + model_st = StatefulLuxLayer(model, ps, st) for epoch in 1:3 for (i, (x, y)) in enumerate(data_train) res = Zygote.withgradient(logitcrossentropy, model_st, x, ps, y) Optimisers.update!(opt_st, ps, res.grad[3]) - if i % 50 == 1 - @info "Epoch: [$(epoch)/3] Batch: [$(i)/$(length(data_train))] Loss: $(res.val)" - end + i % 50 == 1 && @printf "[%s] Epoch: [%d/%d] Batch: [%4d/%4d] Loss: %.5f\n" string(now()) epoch 3 i length(data_train) res.val end acc = accuracy(model, data_test, ps, model_st.st) * 100 - @info "Epoch: [$(epoch)/3] Accuracy: $(acc)" + @printf "[%s] Epoch: [%d/%d] Accuracy: %.5f%%\n" string(now()) epoch 3 acc end - @info "Training complete." - println() + @printf "[%s] Training complete.\n" string(now()) return model, ps, st end @@ -187,9 +170,7 @@ and end up using solvers like `Broyden`, but we can simply slap in any of the fa from NonlinearSolve.jl. Here we will use Newton-Krylov Method: ```@example basic_mnist_deq -with_logger(filtered_logger) do - train_model(NewtonRaphson(; linsolve=KrylovJL_GMRES()), :regdeq) -end +train_model(NewtonRaphson(; linsolve=KrylovJL_GMRES()), :regdeq); nothing # hide ``` @@ -197,9 +178,7 @@ We can also train a continuous DEQ by passing in an ODE solver. Here we will use which tend to be quite fast for continuous Neural Network problems. ```@example basic_mnist_deq -with_logger(filtered_logger) do - train_model(VCAB3(), :deq) -end +train_model(VCAB3(), :deq); nothing # hide ``` diff --git a/docs/src/tutorials/reduced_dim_deq.md b/docs/src/tutorials/reduced_dim_deq.md index 9be703d2..c91f5fcd 100644 --- a/docs/src/tutorials/reduced_dim_deq.md +++ b/docs/src/tutorials/reduced_dim_deq.md @@ -6,7 +6,7 @@ same MNIST example as before, but this time we will use a reduced state size. ```@example reduced_dim_mnist using DeepEquilibriumNetworks, SciMLSensitivity, Lux, NonlinearSolve, OrdinaryDiffEq, - Statistics, Random, Optimisers, LuxCUDA, Zygote, LinearSolve, LoggingExtras + Statistics, Random, Optimisers, LuxCUDA, Zygote, LinearSolve, Dates, Printf using MLDatasets: MNIST using MLDataUtils: LabelEnc, convertlabel, stratifiedobs, batchview @@ -16,13 +16,6 @@ ENV["DATADEPS_ALWAYS_ACCEPT"] = true const cdev = cpu_device() const gdev = gpu_device() -function remove_syms_warning(log_args) - return log_args.message != - "The use of keyword arguments `syms`, `paramsyms` and `indepsym` for `SciMLFunction`s is deprecated. Pass `sys = SymbolCache(syms, paramsyms, indepsym)` instead." -end - -filtered_logger = ActiveFilteredLogger(remove_syms_warning, global_logger()) - function onehot(labels_raw) return convertlabel(LabelEnc.OneOfK, labels_raw, LabelEnc.NativeLabels(collect(0:9))) end @@ -83,12 +76,12 @@ function construct_model(solver; model_type::Symbol=:regdeq) x = randn(rng, Float32, 28, 28, 1, 128) y = onehot(rand(Random.default_rng(), 0:9, 128)) |> gdev - model_ = Lux.Experimental.StatefulLuxLayer(model, ps, st) - @info "warming up forward pass" + model_ = StatefulLuxLayer(model, ps, st) + @printf "[%s] warming up forward pass\n" string(now()) logitcrossentropy(model_, x, ps, y) - @info "warming up backward pass" + @printf "[%s] warming up backward pass\n" string(now()) Zygote.gradient(logitcrossentropy, model_, x, ps, y) - @info "warmup complete" + @printf "[%s] warmup complete\n" string(now()) return model, ps, st end @@ -110,7 +103,7 @@ classify(x) = argmax.(eachcol(x)) function accuracy(model, data, ps, st) total_correct, total = 0, 0 st = Lux.testmode(st) - model = Lux.Experimental.StatefulLuxLayer(model, ps, st) + model = StatefulLuxLayer(model, ps, st) for (x, y) in data target_class = classify(cdev(y)) predicted_class = classify(cdev(model(x))) @@ -123,48 +116,43 @@ end function train_model( solver, model_type; data_train=zip(x_train, y_train), data_test=zip(x_test, y_test)) model, ps, st = construct_model(solver; model_type) - model_st = Lux.Experimental.StatefulLuxLayer(model, nothing, st) + model_st = StatefulLuxLayer(model, nothing, st) - @info "Training Model: $(model_type) with Solver: $(nameof(typeof(solver)))" + @printf "[%s] Training Model: %s with Solver: %s\n" string(now()) model_type nameof(typeof(solver)) opt_st = Optimisers.setup(Adam(0.001), ps) acc = accuracy(model, data_test, ps, st) * 100 - @info "Starting Accuracy: $(acc)" + @printf "[%s] Starting Accuracy: %.5f%%\n" string(now()) acc - @info "Pretrain with unrolling to a depth of 5" + @printf "[%s] Pretrain with unrolling to a depth of 5\n" string(now()) st = Lux.update_state(st, :fixed_depth, Val(5)) - model_st = Lux.Experimental.StatefulLuxLayer(model, ps, st) + model_st = StatefulLuxLayer(model, ps, st) for (i, (x, y)) in enumerate(data_train) res = Zygote.withgradient(logitcrossentropy, model_st, x, ps, y) Optimisers.update!(opt_st, ps, res.grad[3]) - if i % 50 == 1 - @info "Pretraining Batch: [$(i)/$(length(data_train))] Loss: $(res.val)" - end + i % 50 == 1 && @printf "[%s] Pretraining Batch: [%4d/%4d] Loss: %.5f\n" string(now()) i length(data_train) res.val end acc = accuracy(model, data_test, ps, model_st.st) * 100 - @info "Pretraining complete. Accuracy: $(acc)" + @printf "[%s] Pretraining complete. Accuracy: %.5f%%\n" string(now()) acc st = Lux.update_state(st, :fixed_depth, Val(0)) - model_st = Lux.Experimental.StatefulLuxLayer(model, ps, st) + model_st = StatefulLuxLayer(model, ps, st) for epoch in 1:3 for (i, (x, y)) in enumerate(data_train) res = Zygote.withgradient(logitcrossentropy, model_st, x, ps, y) Optimisers.update!(opt_st, ps, res.grad[3]) - if i % 50 == 1 - @info "Epoch: [$(epoch)/3] Batch: [$(i)/$(length(data_train))] Loss: $(res.val)" - end + i % 50 == 1 && @printf "[%s] Epoch: [%d/%d] Batch: [%4d/%4d] Loss: %.5f\n" string(now()) epoch 3 i length(data_train) res.val end acc = accuracy(model, data_test, ps, model_st.st) * 100 - @info "Epoch: [$(epoch)/3] Accuracy: $(acc)" + @printf "[%s] Epoch: [%d/%d] Accuracy: %.5f%%\n" string(now()) epoch 3 acc end - @info "Training complete." - println() + @printf "[%s] Training complete.\n" string(now()) return model, ps, st end @@ -174,15 +162,11 @@ Now we can train our model. We can't use `:regdeq` here currently, but we will s in the future. ```@example reduced_dim_mnist -with_logger(filtered_logger) do - train_model(NewtonRaphson(; linsolve=KrylovJL_GMRES()), :skipdeq) -end +train_model(NewtonRaphson(; linsolve=KrylovJL_GMRES()), :skipdeq) nothing # hide ``` ```@example reduced_dim_mnist -with_logger(filtered_logger) do - train_model(NewtonRaphson(; linsolve=KrylovJL_GMRES()), :deq) -end +train_model(NewtonRaphson(; linsolve=KrylovJL_GMRES()), :deq) nothing # hide ``` diff --git a/src/layers.jl b/src/layers.jl index e20b179b..466b9a64 100644 --- a/src/layers.jl +++ b/src/layers.jl @@ -316,8 +316,7 @@ julia> model(x, ps, st); """ function MultiScaleDeepEquilibriumNetwork( main_layers::Tuple, mapping_layers::Matrix, post_fuse_layer::Union{Nothing, Tuple}, - solver, scales; jacobian_regularization=nothing, kwargs...) - @assert jacobian_regularization===nothing "Jacobian Regularization is not supported yet for MultiScale Models." + solver, scales; kwargs...) l1 = Parallel(nothing, main_layers...) l2 = BranchLayer(Parallel.(+, map(x -> tuple(x...), eachrow(mapping_layers))...)...) From 031deac9b9f9ca4850ff48c92265f286ed3557d5 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 24 Apr 2024 18:13:40 -0400 Subject: [PATCH 5/6] Test with the new frules --- Manifest.toml | 8 +++++--- Project.toml | 1 + docs/src/tutorials/basic_mnist_deq.md | 6 ++++-- docs/src/tutorials/reduced_dim_deq.md | 6 ++++-- ext/DeepEquilibriumNetworksZygoteExt.jl | 1 - src/layers.jl | 5 ++--- src/utils.jl | 4 ++-- test/layers_tests.jl | 2 +- 8 files changed, 19 insertions(+), 14 deletions(-) diff --git a/Manifest.toml b/Manifest.toml index ca0bc6d1..87ebbe10 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -2,7 +2,7 @@ julia_version = "1.10.2" manifest_format = "2.0" -project_hash = "df8a9208b4276382055ff54a66a4252730918e13" +project_hash = "914538f40e552ac89a85de7921db9eaf76294f1a" [[deps.ADTypes]] git-tree-sha1 = "fcdb00b4d412b80ab08e39978e3bdef579e5e224" @@ -574,9 +574,11 @@ version = "0.1.20" [[deps.LuxLib]] deps = ["ArrayInterface", "ChainRulesCore", "FastBroadcast", "FastClosures", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "LuxCore", "Markdown", "NNlib", "PrecompileTools", "Random", "Reexport", "Statistics", "Strided"] -git-tree-sha1 = "edbf65f5ceb15ebbfad9d03c6a846d83b9a97baf" +git-tree-sha1 = "8143e3dbdcfff587e9595b58c4b637e74c090fbf" +repo-rev = "ap/more_frules" +repo-url = "https://github.com/LuxDL/LuxLib.jl.git" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" -version = "0.3.16" +version = "0.3.17" [deps.LuxLib.extensions] LuxLibAMDGPUExt = "AMDGPU" diff --git a/Project.toml b/Project.toml index d3bab847..37d5d09d 100644 --- a/Project.toml +++ b/Project.toml @@ -13,6 +13,7 @@ DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" +LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" diff --git a/docs/src/tutorials/basic_mnist_deq.md b/docs/src/tutorials/basic_mnist_deq.md index 644f52fe..3684f4a7 100644 --- a/docs/src/tutorials/basic_mnist_deq.md +++ b/docs/src/tutorials/basic_mnist_deq.md @@ -138,7 +138,8 @@ function train_model( for (i, (x, y)) in enumerate(data_train) res = Zygote.withgradient(logitcrossentropy, model_st, x, ps, y) Optimisers.update!(opt_st, ps, res.grad[3]) - i % 50 == 1 && @printf "[%s] Pretraining Batch: [%4d/%4d] Loss: %.5f\n" string(now()) i length(data_train) res.val + i % 50 == 1 && + @printf "[%s] Pretraining Batch: [%4d/%4d] Loss: %.5f\n" string(now()) i length(data_train) res.val end acc = accuracy(model, data_test, ps, model_st.st) * 100 @@ -151,7 +152,8 @@ function train_model( for (i, (x, y)) in enumerate(data_train) res = Zygote.withgradient(logitcrossentropy, model_st, x, ps, y) Optimisers.update!(opt_st, ps, res.grad[3]) - i % 50 == 1 && @printf "[%s] Epoch: [%d/%d] Batch: [%4d/%4d] Loss: %.5f\n" string(now()) epoch 3 i length(data_train) res.val + i % 50 == 1 && + @printf "[%s] Epoch: [%d/%d] Batch: [%4d/%4d] Loss: %.5f\n" string(now()) epoch 3 i length(data_train) res.val end acc = accuracy(model, data_test, ps, model_st.st) * 100 diff --git a/docs/src/tutorials/reduced_dim_deq.md b/docs/src/tutorials/reduced_dim_deq.md index c91f5fcd..9f72ac69 100644 --- a/docs/src/tutorials/reduced_dim_deq.md +++ b/docs/src/tutorials/reduced_dim_deq.md @@ -132,7 +132,8 @@ function train_model( for (i, (x, y)) in enumerate(data_train) res = Zygote.withgradient(logitcrossentropy, model_st, x, ps, y) Optimisers.update!(opt_st, ps, res.grad[3]) - i % 50 == 1 && @printf "[%s] Pretraining Batch: [%4d/%4d] Loss: %.5f\n" string(now()) i length(data_train) res.val + i % 50 == 1 && + @printf "[%s] Pretraining Batch: [%4d/%4d] Loss: %.5f\n" string(now()) i length(data_train) res.val end acc = accuracy(model, data_test, ps, model_st.st) * 100 @@ -145,7 +146,8 @@ function train_model( for (i, (x, y)) in enumerate(data_train) res = Zygote.withgradient(logitcrossentropy, model_st, x, ps, y) Optimisers.update!(opt_st, ps, res.grad[3]) - i % 50 == 1 && @printf "[%s] Epoch: [%d/%d] Batch: [%4d/%4d] Loss: %.5f\n" string(now()) epoch 3 i length(data_train) res.val + i % 50 == 1 && + @printf "[%s] Epoch: [%d/%d] Batch: [%4d/%4d] Loss: %.5f\n" string(now()) epoch 3 i length(data_train) res.val end acc = accuracy(model, data_test, ps, model_st.st) * 100 diff --git a/ext/DeepEquilibriumNetworksZygoteExt.jl b/ext/DeepEquilibriumNetworksZygoteExt.jl index 688bd2ca..a04697e0 100644 --- a/ext/DeepEquilibriumNetworksZygoteExt.jl +++ b/ext/DeepEquilibriumNetworksZygoteExt.jl @@ -50,7 +50,6 @@ function CRC.rrule( end ## Don't remove `ad`. See https://github.com/ericphanson/ExplicitImports.jl/issues/33 -## FIXME: This will be broken in the new Lux release let's fix this function DEQs.__estimate_jacobian_trace(ad::AutoZygote, model, z, x, rng) return mean(abs2, __compute_vector_jacobian_product(model, model.ps, z, x, rng)) end diff --git a/src/layers.jl b/src/layers.jl index 466b9a64..995f94db 100644 --- a/src/layers.jl +++ b/src/layers.jl @@ -314,9 +314,8 @@ julia> model(x, ps, st); ``` """ -function MultiScaleDeepEquilibriumNetwork( - main_layers::Tuple, mapping_layers::Matrix, post_fuse_layer::Union{Nothing, Tuple}, - solver, scales; kwargs...) +function MultiScaleDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Matrix, + post_fuse_layer::Union{Nothing, Tuple}, solver, scales; kwargs...) l1 = Parallel(nothing, main_layers...) l2 = BranchLayer(Parallel.(+, map(x -> tuple(x...), eachrow(mapping_layers))...)...) diff --git a/src/utils.jl b/src/utils.jl index dfc13210..647636dc 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -87,8 +87,8 @@ CRC.@non_differentiable __zeros_init(::Any, ::Any) ## Don't rely on SciMLSensitivity's choice @inline __default_sensealg(prob) = nothing -@inline function __gaussian_like(rng::AbstractRNG, x) - y = similar(x) +@inline function __gaussian_like(rng::AbstractRNG, x::AbstractArray) + y = similar(x)::typeof(x) randn!(rng, y) return y end diff --git a/test/layers_tests.jl b/test/layers_tests.jl index 75b6f68d..aa19ea45 100644 --- a/test/layers_tests.jl +++ b/test/layers_tests.jl @@ -34,7 +34,7 @@ end jacobian_regularizations = ongpu ? _jacobian_regularizations[1:(end - 1)] : _jacobian_regularizations - @testset "Solver: $(__nameof(solver))" for solver in SOLVERS, + @testset "Solver: $(__nameof(solver)) | Model Type: $(mtype) | Jac. Reg: $(jacobian_regularization)" for solver in SOLVERS, mtype in model_type, jacobian_regularization in jacobian_regularizations From 7d9c2fafd58a2472d63fd373df0c5128441311ad Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 25 Apr 2024 12:35:51 -0400 Subject: [PATCH 6/6] Remove Manifest --- Manifest.toml | 1073 ------------------------------------------------- Project.toml | 3 +- 2 files changed, 1 insertion(+), 1075 deletions(-) delete mode 100644 Manifest.toml diff --git a/Manifest.toml b/Manifest.toml deleted file mode 100644 index 87ebbe10..00000000 --- a/Manifest.toml +++ /dev/null @@ -1,1073 +0,0 @@ -# This file is machine-generated - editing it directly is not advised - -julia_version = "1.10.2" -manifest_format = "2.0" -project_hash = "914538f40e552ac89a85de7921db9eaf76294f1a" - -[[deps.ADTypes]] -git-tree-sha1 = "fcdb00b4d412b80ab08e39978e3bdef579e5e224" -uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b" -version = "1.0.0" -weakdeps = ["ChainRulesCore", "EnzymeCore"] - - [deps.ADTypes.extensions] - ADTypesChainRulesCoreExt = "ChainRulesCore" - ADTypesEnzymeCoreExt = "EnzymeCore" - -[[deps.Accessors]] -deps = ["CompositionsBase", "ConstructionBase", "Dates", "InverseFunctions", "LinearAlgebra", "MacroTools", "Markdown", "Test"] -git-tree-sha1 = "c0d491ef0b135fd7d63cbc6404286bc633329425" -uuid = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" -version = "0.1.36" - - [deps.Accessors.extensions] - AccessorsAxisKeysExt = "AxisKeys" - AccessorsIntervalSetsExt = "IntervalSets" - AccessorsStaticArraysExt = "StaticArrays" - AccessorsStructArraysExt = "StructArrays" - AccessorsUnitfulExt = "Unitful" - - [deps.Accessors.weakdeps] - AxisKeys = "94b1ba4f-4ee9-5380-92f1-94cde586c3c5" - IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" - Requires = "ae029012-a4dd-5104-9daa-d747884805df" - StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" - StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" - Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" - -[[deps.Adapt]] -deps = ["LinearAlgebra", "Requires"] -git-tree-sha1 = "6a55b747d1812e699320963ffde36f1ebdda4099" -uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -version = "4.0.4" -weakdeps = ["StaticArrays"] - - [deps.Adapt.extensions] - AdaptStaticArraysExt = "StaticArrays" - -[[deps.ArgTools]] -uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" -version = "1.1.1" - -[[deps.ArrayInterface]] -deps = ["Adapt", "LinearAlgebra", "SparseArrays", "SuiteSparse"] -git-tree-sha1 = "133a240faec6e074e07c31ee75619c90544179cf" -uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" -version = "7.10.0" - - [deps.ArrayInterface.extensions] - ArrayInterfaceBandedMatricesExt = "BandedMatrices" - ArrayInterfaceBlockBandedMatricesExt = "BlockBandedMatrices" - ArrayInterfaceCUDAExt = "CUDA" - ArrayInterfaceCUDSSExt = "CUDSS" - ArrayInterfaceChainRulesExt = "ChainRules" - ArrayInterfaceGPUArraysCoreExt = "GPUArraysCore" - ArrayInterfaceReverseDiffExt = "ReverseDiff" - ArrayInterfaceStaticArraysCoreExt = "StaticArraysCore" - ArrayInterfaceTrackerExt = "Tracker" - - [deps.ArrayInterface.weakdeps] - BandedMatrices = "aae01518-5342-5314-be14-df237901396f" - BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0" - CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" - CUDSS = "45b445bb-4962-46a0-9369-b4df9d0f772e" - ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" - GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" - ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" - StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" - Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" - -[[deps.Artifacts]] -uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" - -[[deps.Atomix]] -deps = ["UnsafeAtomics"] -git-tree-sha1 = "c06a868224ecba914baa6942988e2f2aade419be" -uuid = "a9b6321e-bd34-4604-b9c9-b65b8de01458" -version = "0.1.0" - -[[deps.Base64]] -uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" - -[[deps.BitTwiddlingConvenienceFunctions]] -deps = ["Static"] -git-tree-sha1 = "0c5f81f47bbbcf4aea7b2959135713459170798b" -uuid = "62783981-4cbd-42fc-bca8-16325de8dc4b" -version = "0.1.5" - -[[deps.CEnum]] -git-tree-sha1 = "389ad5c84de1ae7cf0e28e381131c98ea87d54fc" -uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" -version = "0.5.0" - -[[deps.CPUSummary]] -deps = ["CpuId", "IfElse", "PrecompileTools", "Static"] -git-tree-sha1 = "601f7e7b3d36f18790e2caf83a882d88e9b71ff1" -uuid = "2a0fbf3d-bb9c-48f3-b0a9-814d99fd7ab9" -version = "0.2.4" - -[[deps.ChainRulesCore]] -deps = ["Compat", "LinearAlgebra"] -git-tree-sha1 = "575cd02e080939a33b6df6c5853d14924c08e35b" -uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "1.23.0" -weakdeps = ["SparseArrays"] - - [deps.ChainRulesCore.extensions] - ChainRulesCoreSparseArraysExt = "SparseArrays" - -[[deps.CloseOpenIntervals]] -deps = ["Static", "StaticArrayInterface"] -git-tree-sha1 = "70232f82ffaab9dc52585e0dd043b5e0c6b714f1" -uuid = "fb6a15b2-703c-40df-9091-08a04967cfa9" -version = "0.1.12" - -[[deps.CommonSolve]] -git-tree-sha1 = "0eee5eb66b1cf62cd6ad1b460238e60e4b09400c" -uuid = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2" -version = "0.2.4" - -[[deps.CommonSubexpressions]] -deps = ["MacroTools", "Test"] -git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7" -uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" -version = "0.3.0" - -[[deps.Compat]] -deps = ["TOML", "UUIDs"] -git-tree-sha1 = "c955881e3c981181362ae4088b35995446298b80" -uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "4.14.0" -weakdeps = ["Dates", "LinearAlgebra"] - - [deps.Compat.extensions] - CompatLinearAlgebraExt = "LinearAlgebra" - -[[deps.CompilerSupportLibraries_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" -version = "1.1.0+0" - -[[deps.CompositionsBase]] -git-tree-sha1 = "802bb88cd69dfd1509f6670416bd4434015693ad" -uuid = "a33af91c-f02d-484b-be07-31d278c5ca2b" -version = "0.1.2" -weakdeps = ["InverseFunctions"] - - [deps.CompositionsBase.extensions] - CompositionsBaseInverseFunctionsExt = "InverseFunctions" - -[[deps.ConcreteStructs]] -git-tree-sha1 = "f749037478283d372048690eb3b5f92a79432b34" -uuid = "2569d6c7-a4a2-43d3-a901-331e8e4be471" -version = "0.2.3" - -[[deps.ConstructionBase]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "260fd2400ed2dab602a7c15cf10c1933c59930a2" -uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9" -version = "1.5.5" - - [deps.ConstructionBase.extensions] - ConstructionBaseIntervalSetsExt = "IntervalSets" - ConstructionBaseStaticArraysExt = "StaticArrays" - - [deps.ConstructionBase.weakdeps] - IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" - StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" - -[[deps.CpuId]] -deps = ["Markdown"] -git-tree-sha1 = "fcbb72b032692610bfbdb15018ac16a36cf2e406" -uuid = "adafc99b-e345-5852-983c-f28acb93d879" -version = "0.3.1" - -[[deps.DataAPI]] -git-tree-sha1 = "abe83f3a2f1b857aac70ef8b269080af17764bbe" -uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" -version = "1.16.0" - -[[deps.DataStructures]] -deps = ["Compat", "InteractiveUtils", "OrderedCollections"] -git-tree-sha1 = "1d0a14036acb104d9e89698bd408f63ab58cdc82" -uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -version = "0.18.20" - -[[deps.DataValueInterfaces]] -git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" -uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464" -version = "1.0.0" - -[[deps.Dates]] -deps = ["Printf"] -uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" - -[[deps.DiffEqBase]] -deps = ["ArrayInterface", "ConcreteStructs", "DataStructures", "DocStringExtensions", "EnumX", "EnzymeCore", "FastBroadcast", "FastClosures", "ForwardDiff", "FunctionWrappers", "FunctionWrappersWrappers", "LinearAlgebra", "Logging", "Markdown", "MuladdMacro", "Parameters", "PreallocationTools", "PrecompileTools", "Printf", "RecursiveArrayTools", "Reexport", "SciMLBase", "SciMLOperators", "Setfield", "SparseArrays", "Static", "StaticArraysCore", "Statistics", "Tricks", "TruncatedStacktraces"] -git-tree-sha1 = "531c53fd0405716712a8b4960216c3b7b5ec89b9" -uuid = "2b5f629d-d688-5b77-993f-72d75c75574e" -version = "6.149.1" - - [deps.DiffEqBase.extensions] - DiffEqBaseChainRulesCoreExt = "ChainRulesCore" - DiffEqBaseDistributionsExt = "Distributions" - DiffEqBaseEnzymeExt = ["ChainRulesCore", "Enzyme"] - DiffEqBaseGeneralizedGeneratedExt = "GeneralizedGenerated" - DiffEqBaseMPIExt = "MPI" - DiffEqBaseMeasurementsExt = "Measurements" - DiffEqBaseMonteCarloMeasurementsExt = "MonteCarloMeasurements" - DiffEqBaseReverseDiffExt = "ReverseDiff" - DiffEqBaseTrackerExt = "Tracker" - DiffEqBaseUnitfulExt = "Unitful" - - [deps.DiffEqBase.weakdeps] - ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" - Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" - GeneralizedGenerated = "6b9d7cbe-bcb9-11e9-073f-15a7a543e2eb" - MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" - Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7" - MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca" - ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" - Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" - Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" - -[[deps.DiffEqCallbacks]] -deps = ["DataStructures", "DiffEqBase", "ForwardDiff", "Functors", "LinearAlgebra", "Markdown", "NLsolve", "Parameters", "RecipesBase", "RecursiveArrayTools", "SciMLBase", "StaticArraysCore"] -git-tree-sha1 = "ee954c8b9d348b7a8a6aec5f28288bf5adecd4ee" -uuid = "459566f4-90b8-5000-8ac3-15dfb0a30def" -version = "2.37.0" - - [deps.DiffEqCallbacks.weakdeps] - OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" - Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4" - -[[deps.DiffResults]] -deps = ["StaticArraysCore"] -git-tree-sha1 = "782dd5f4561f5d267313f23853baaaa4c52ea621" -uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" -version = "1.1.0" - -[[deps.DiffRules]] -deps = ["IrrationalConstants", "LogExpFunctions", "NaNMath", "Random", "SpecialFunctions"] -git-tree-sha1 = "23163d55f885173722d1e4cf0f6110cdbaf7e272" -uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" -version = "1.15.1" - -[[deps.Distances]] -deps = ["LinearAlgebra", "Statistics", "StatsAPI"] -git-tree-sha1 = "66c4c81f259586e8f002eacebc177e1fb06363b0" -uuid = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" -version = "0.10.11" -weakdeps = ["ChainRulesCore", "SparseArrays"] - - [deps.Distances.extensions] - DistancesChainRulesCoreExt = "ChainRulesCore" - DistancesSparseArraysExt = "SparseArrays" - -[[deps.Distributed]] -deps = ["Random", "Serialization", "Sockets"] -uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" - -[[deps.DocStringExtensions]] -deps = ["LibGit2"] -git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d" -uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" -version = "0.9.3" - -[[deps.Downloads]] -deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] -uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" -version = "1.6.0" - -[[deps.EnumX]] -git-tree-sha1 = "bdb1942cd4c45e3c678fd11569d5cccd80976237" -uuid = "4e289a0a-7415-4d19-859d-a7e5c4648b56" -version = "1.0.4" - -[[deps.EnzymeCore]] -git-tree-sha1 = "18394bc78ac2814ff38fe5e0c9dc2cd171e2810c" -uuid = "f151be2c-9106-41f4-ab19-57ee4f262869" -version = "0.7.2" -weakdeps = ["Adapt"] - - [deps.EnzymeCore.extensions] - AdaptExt = "Adapt" - -[[deps.ExprTools]] -git-tree-sha1 = "27415f162e6028e81c72b82ef756bf321213b6ec" -uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04" -version = "0.1.10" - -[[deps.FastBroadcast]] -deps = ["ArrayInterface", "LinearAlgebra", "Polyester", "Static", "StaticArrayInterface", "StrideArraysCore"] -git-tree-sha1 = "a6e756a880fc419c8b41592010aebe6a5ce09136" -uuid = "7034ab61-46d4-4ed7-9d0f-46aef9175898" -version = "0.2.8" - -[[deps.FastClosures]] -git-tree-sha1 = "acebe244d53ee1b461970f8910c235b259e772ef" -uuid = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" -version = "0.3.2" - -[[deps.FileWatching]] -uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" - -[[deps.FiniteDiff]] -deps = ["ArrayInterface", "LinearAlgebra", "Requires", "Setfield", "SparseArrays"] -git-tree-sha1 = "2de436b72c3422940cbe1367611d137008af7ec3" -uuid = "6a86dc24-6348-571c-b903-95158fe2bd41" -version = "2.23.1" - - [deps.FiniteDiff.extensions] - FiniteDiffBandedMatricesExt = "BandedMatrices" - FiniteDiffBlockBandedMatricesExt = "BlockBandedMatrices" - FiniteDiffStaticArraysExt = "StaticArrays" - - [deps.FiniteDiff.weakdeps] - BandedMatrices = "aae01518-5342-5314-be14-df237901396f" - BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0" - StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" - -[[deps.ForwardDiff]] -deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions"] -git-tree-sha1 = "cf0fe81336da9fb90944683b8c41984b08793dad" -uuid = "f6369f11-7733-5829-9624-2563aa707210" -version = "0.10.36" -weakdeps = ["StaticArrays"] - - [deps.ForwardDiff.extensions] - ForwardDiffStaticArraysExt = "StaticArrays" - -[[deps.FunctionWrappers]] -git-tree-sha1 = "d62485945ce5ae9c0c48f124a84998d755bae00e" -uuid = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e" -version = "1.1.3" - -[[deps.FunctionWrappersWrappers]] -deps = ["FunctionWrappers"] -git-tree-sha1 = "b104d487b34566608f8b4e1c39fb0b10aa279ff8" -uuid = "77dc65aa-8811-40c2-897b-53d922fa7daf" -version = "0.1.3" - -[[deps.Functors]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "d3e63d9fa13f8eaa2f06f64949e2afc593ff52c2" -uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196" -version = "0.4.10" - -[[deps.Future]] -deps = ["Random"] -uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" - -[[deps.GPUArraysCore]] -deps = ["Adapt"] -git-tree-sha1 = "ec632f177c0d990e64d955ccc1b8c04c485a0950" -uuid = "46192b85-c4d5-4398-a991-12ede77f4527" -version = "0.1.6" - -[[deps.IfElse]] -git-tree-sha1 = "debdd00ffef04665ccbb3e150747a77560e8fad1" -uuid = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" -version = "0.1.1" - -[[deps.InteractiveUtils]] -deps = ["Markdown"] -uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" - -[[deps.InverseFunctions]] -deps = ["Test"] -git-tree-sha1 = "896385798a8d49a255c398bd49162062e4a4c435" -uuid = "3587e190-3f89-42d0-90ee-14403ec27112" -version = "0.1.13" -weakdeps = ["Dates"] - - [deps.InverseFunctions.extensions] - DatesExt = "Dates" - -[[deps.IrrationalConstants]] -git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2" -uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" -version = "0.2.2" - -[[deps.IteratorInterfaceExtensions]] -git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" -uuid = "82899510-4779-5014-852e-03e436cf321d" -version = "1.0.0" - -[[deps.JLLWrappers]] -deps = ["Artifacts", "Preferences"] -git-tree-sha1 = "7e5d6779a1e09a36db2a7b6cff50942a0a7d0fca" -uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" -version = "1.5.0" - -[[deps.KernelAbstractions]] -deps = ["Adapt", "Atomix", "InteractiveUtils", "LinearAlgebra", "MacroTools", "PrecompileTools", "Requires", "SparseArrays", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"] -git-tree-sha1 = "ed7167240f40e62d97c1f5f7735dea6de3cc5c49" -uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c" -version = "0.9.18" -weakdeps = ["EnzymeCore"] - - [deps.KernelAbstractions.extensions] - EnzymeExt = "EnzymeCore" - -[[deps.LLVM]] -deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Preferences", "Printf", "Requires", "Unicode"] -git-tree-sha1 = "839c82932db86740ae729779e610f07a1640be9a" -uuid = "929cbde3-209d-540e-8aea-75f648917ca0" -version = "6.6.3" - - [deps.LLVM.extensions] - BFloat16sExt = "BFloat16s" - - [deps.LLVM.weakdeps] - BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" - -[[deps.LLVMExtra_jll]] -deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] -git-tree-sha1 = "88b916503aac4fb7f701bb625cd84ca5dd1677bc" -uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab" -version = "0.0.29+0" - -[[deps.LayoutPointers]] -deps = ["ArrayInterface", "LinearAlgebra", "ManualMemory", "SIMDTypes", "Static", "StaticArrayInterface"] -git-tree-sha1 = "62edfee3211981241b57ff1cedf4d74d79519277" -uuid = "10f19ff3-798f-405d-979b-55457f8fc047" -version = "0.1.15" - -[[deps.LazyArtifacts]] -deps = ["Artifacts", "Pkg"] -uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3" - -[[deps.LibCURL]] -deps = ["LibCURL_jll", "MozillaCACerts_jll"] -uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" -version = "0.6.4" - -[[deps.LibCURL_jll]] -deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] -uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" -version = "8.4.0+0" - -[[deps.LibGit2]] -deps = ["Base64", "LibGit2_jll", "NetworkOptions", "Printf", "SHA"] -uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" - -[[deps.LibGit2_jll]] -deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll"] -uuid = "e37daf67-58a4-590a-8e99-b0245dd2ffc5" -version = "1.6.4+0" - -[[deps.LibSSH2_jll]] -deps = ["Artifacts", "Libdl", "MbedTLS_jll"] -uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" -version = "1.11.0+1" - -[[deps.Libdl]] -uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" - -[[deps.LineSearches]] -deps = ["LinearAlgebra", "NLSolversBase", "NaNMath", "Parameters", "Printf"] -git-tree-sha1 = "7bbea35cec17305fc70a0e5b4641477dc0789d9d" -uuid = "d3d80556-e9d4-5f37-9878-2ab0fcc64255" -version = "7.2.0" - -[[deps.LinearAlgebra]] -deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"] -uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" - -[[deps.LogExpFunctions]] -deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"] -git-tree-sha1 = "18144f3e9cbe9b15b070288eef858f71b291ce37" -uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" -version = "0.3.27" - - [deps.LogExpFunctions.extensions] - LogExpFunctionsChainRulesCoreExt = "ChainRulesCore" - LogExpFunctionsChangesOfVariablesExt = "ChangesOfVariables" - LogExpFunctionsInverseFunctionsExt = "InverseFunctions" - - [deps.LogExpFunctions.weakdeps] - ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" - InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" - -[[deps.Logging]] -uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" - -[[deps.Lux]] -deps = ["ADTypes", "Adapt", "ArrayInterface", "ChainRulesCore", "ConcreteStructs", "ConstructionBase", "FastClosures", "Functors", "GPUArraysCore", "LinearAlgebra", "LuxCore", "LuxDeviceUtils", "LuxLib", "MacroTools", "Markdown", "PrecompileTools", "Preferences", "Random", "Reexport", "Setfield", "Statistics", "WeightInitializers"] -git-tree-sha1 = "ae13ecbe29ee7432dfd477b233db43c462b6a4ff" -repo-rev = "ap/nested_ad" -repo-url = "https://github.com/LuxDL/Lux.jl.git" -uuid = "b2108857-7c20-44ae-9111-449ecde12c47" -version = "0.5.38" - - [deps.Lux.extensions] - LuxComponentArraysExt = "ComponentArrays" - LuxComponentArraysReverseDiffExt = ["ComponentArrays", "ReverseDiff"] - LuxDynamicExpressionsExt = "DynamicExpressions" - LuxDynamicExpressionsForwardDiffExt = ["DynamicExpressions", "ForwardDiff"] - LuxFluxExt = "Flux" - LuxForwardDiffExt = "ForwardDiff" - LuxLuxAMDGPUExt = "LuxAMDGPU" - LuxMLUtilsExt = "MLUtils" - LuxMPIExt = "MPI" - LuxMPINCCLExt = ["CUDA", "MPI", "NCCL"] - LuxOptimisersExt = "Optimisers" - LuxReverseDiffExt = "ReverseDiff" - LuxSimpleChainsExt = "SimpleChains" - LuxTrackerExt = "Tracker" - LuxZygoteExt = "Zygote" - - [deps.Lux.weakdeps] - CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" - ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" - DynamicExpressions = "a40a106e-89c9-4ca8-8020-a735e8728b6b" - Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" - ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" - LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b" - MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" - MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" - NCCL = "3fe64909-d7a1-4096-9b7d-7a0f12cf0f6b" - Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" - ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" - SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5" - Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" - Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" - -[[deps.LuxCore]] -deps = ["FastClosures", "Functors", "Random", "Setfield"] -git-tree-sha1 = "f799f3aa8599f79ed5e2c9fbaf74907c1ebe15ce" -uuid = "bb33d45b-7691-41d6-9220-0943567d0623" -version = "0.1.14" - -[[deps.LuxDeviceUtils]] -deps = ["Adapt", "ChainRulesCore", "FastClosures", "Functors", "LuxCore", "PrecompileTools", "Preferences", "Random"] -git-tree-sha1 = "bbcf12d598b8ef6d2b12e506b1d18125552c3b27" -uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553" -version = "0.1.20" - - [deps.LuxDeviceUtils.extensions] - LuxDeviceUtilsAMDGPUExt = "AMDGPU" - LuxDeviceUtilsCUDAExt = "CUDA" - LuxDeviceUtilsFillArraysExt = "FillArrays" - LuxDeviceUtilsGPUArraysExt = "GPUArrays" - LuxDeviceUtilsLuxAMDGPUExt = "LuxAMDGPU" - LuxDeviceUtilsLuxCUDAExt = "LuxCUDA" - LuxDeviceUtilsMetalGPUArraysExt = ["GPUArrays", "Metal"] - LuxDeviceUtilsRecursiveArrayToolsExt = "RecursiveArrayTools" - LuxDeviceUtilsSparseArraysExt = "SparseArrays" - LuxDeviceUtilsZygoteExt = "Zygote" - - [deps.LuxDeviceUtils.weakdeps] - AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" - CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" - FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" - GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" - LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b" - LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" - Metal = "dde4c033-4e86-420c-a63e-0dd931031962" - RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" - SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" - Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" - -[[deps.LuxLib]] -deps = ["ArrayInterface", "ChainRulesCore", "FastBroadcast", "FastClosures", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "LuxCore", "Markdown", "NNlib", "PrecompileTools", "Random", "Reexport", "Statistics", "Strided"] -git-tree-sha1 = "8143e3dbdcfff587e9595b58c4b637e74c090fbf" -repo-rev = "ap/more_frules" -repo-url = "https://github.com/LuxDL/LuxLib.jl.git" -uuid = "82251201-b29d-42c6-8e01-566dec8acb11" -version = "0.3.17" - - [deps.LuxLib.extensions] - LuxLibAMDGPUExt = "AMDGPU" - LuxLibForwardDiffExt = "ForwardDiff" - LuxLibReverseDiffExt = "ReverseDiff" - LuxLibTrackerAMDGPUExt = ["AMDGPU", "Tracker"] - LuxLibTrackerExt = "Tracker" - LuxLibTrackercuDNNExt = ["CUDA", "Tracker", "cuDNN"] - LuxLibcuDNNExt = ["CUDA", "cuDNN"] - - [deps.LuxLib.weakdeps] - AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" - CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" - ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" - ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" - Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" - cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" - -[[deps.MacroTools]] -deps = ["Markdown", "Random"] -git-tree-sha1 = "2fa9ee3e63fd3a4f7a9a4f4744a52f4856de82df" -uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" -version = "0.5.13" - -[[deps.ManualMemory]] -git-tree-sha1 = "bcaef4fc7a0cfe2cba636d84cda54b5e4e4ca3cd" -uuid = "d125e4d3-2237-4719-b19c-fa641b8a4667" -version = "0.1.8" - -[[deps.Markdown]] -deps = ["Base64"] -uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" - -[[deps.MbedTLS_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" -version = "2.28.2+1" - -[[deps.MozillaCACerts_jll]] -uuid = "14a3606d-f60d-562e-9121-12d972cd8159" -version = "2023.1.10" - -[[deps.MuladdMacro]] -git-tree-sha1 = "cac9cc5499c25554cba55cd3c30543cff5ca4fab" -uuid = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221" -version = "0.2.4" - -[[deps.NLSolversBase]] -deps = ["DiffResults", "Distributed", "FiniteDiff", "ForwardDiff"] -git-tree-sha1 = "a0b464d183da839699f4c79e7606d9d186ec172c" -uuid = "d41bc354-129a-5804-8e4c-c37616107c6c" -version = "7.8.3" - -[[deps.NLsolve]] -deps = ["Distances", "LineSearches", "LinearAlgebra", "NLSolversBase", "Printf", "Reexport"] -git-tree-sha1 = "019f12e9a1a7880459d0173c182e6a99365d7ac1" -uuid = "2774e3e8-f4cf-5e23-947b-6d7e65073b56" -version = "4.5.1" - -[[deps.NNlib]] -deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"] -git-tree-sha1 = "5055845dd316575ae2fc1f6dcb3545ff15fe547a" -uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.9.14" - - [deps.NNlib.extensions] - NNlibAMDGPUExt = "AMDGPU" - NNlibCUDACUDNNExt = ["CUDA", "cuDNN"] - NNlibCUDAExt = "CUDA" - NNlibEnzymeCoreExt = "EnzymeCore" - - [deps.NNlib.weakdeps] - AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" - CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" - EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" - cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" - -[[deps.NaNMath]] -deps = ["OpenLibm_jll"] -git-tree-sha1 = "0877504529a3e5c3343c6f8b4c0381e57e4387e4" -uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" -version = "1.0.2" - -[[deps.NetworkOptions]] -uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" -version = "1.2.0" - -[[deps.OpenBLAS_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] -uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" -version = "0.3.23+4" - -[[deps.OpenLibm_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "05823500-19ac-5b8b-9628-191a04bc5112" -version = "0.8.1+2" - -[[deps.OpenSpecFun_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1" -uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" -version = "0.5.5+0" - -[[deps.OrderedCollections]] -git-tree-sha1 = "dfdf5519f235516220579f949664f1bf44e741c5" -uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" -version = "1.6.3" - -[[deps.PackageExtensionCompat]] -git-tree-sha1 = "fb28e33b8a95c4cee25ce296c817d89cc2e53518" -uuid = "65ce6f38-6b18-4e1d-a461-8949797d7930" -version = "1.0.2" -weakdeps = ["Requires", "TOML"] - -[[deps.Parameters]] -deps = ["OrderedCollections", "UnPack"] -git-tree-sha1 = "34c0e9ad262e5f7fc75b10a9952ca7692cfc5fbe" -uuid = "d96e819e-fc66-5662-9728-84c9c7592b0a" -version = "0.12.3" - -[[deps.PartialFunctions]] -deps = ["MacroTools"] -git-tree-sha1 = "47b49a4dbc23b76682205c646252c0f9e1eb75af" -uuid = "570af359-4316-4cb7-8c74-252c00c2016b" -version = "1.2.0" - -[[deps.Pkg]] -deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] -uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -version = "1.10.0" - -[[deps.Polyester]] -deps = ["ArrayInterface", "BitTwiddlingConvenienceFunctions", "CPUSummary", "IfElse", "ManualMemory", "PolyesterWeave", "Requires", "Static", "StaticArrayInterface", "StrideArraysCore", "ThreadingUtilities"] -git-tree-sha1 = "2ba5f33cbb51a85ef58a850749492b08f9bf2193" -uuid = "f517fe37-dbe3-4b94-8317-1923a5111588" -version = "0.7.13" - -[[deps.PolyesterWeave]] -deps = ["BitTwiddlingConvenienceFunctions", "CPUSummary", "IfElse", "Static", "ThreadingUtilities"] -git-tree-sha1 = "240d7170f5ffdb285f9427b92333c3463bf65bf6" -uuid = "1d0040c9-8b98-4ee7-8388-3f51789ca0ad" -version = "0.2.1" - -[[deps.PreallocationTools]] -deps = ["Adapt", "ArrayInterface", "ForwardDiff"] -git-tree-sha1 = "a660e9daab5db07adf3dedfe09b435cc530d855e" -uuid = "d236fae5-4411-538c-8e31-a6e3d9e00b46" -version = "0.4.21" - - [deps.PreallocationTools.extensions] - PreallocationToolsReverseDiffExt = "ReverseDiff" - - [deps.PreallocationTools.weakdeps] - ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" - -[[deps.PrecompileTools]] -deps = ["Preferences"] -git-tree-sha1 = "5aa36f7049a63a1528fe8f7c3f2113413ffd4e1f" -uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a" -version = "1.2.1" - -[[deps.Preferences]] -deps = ["TOML"] -git-tree-sha1 = "9306f6085165d270f7e3db02af26a400d580f5c6" -uuid = "21216c6a-2e73-6563-6e65-726566657250" -version = "1.4.3" - -[[deps.Printf]] -deps = ["Unicode"] -uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" - -[[deps.REPL]] -deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] -uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" - -[[deps.Random]] -deps = ["SHA"] -uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" - -[[deps.RecipesBase]] -deps = ["PrecompileTools"] -git-tree-sha1 = "5c3d09cc4f31f5fc6af001c250bf1278733100ff" -uuid = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" -version = "1.3.4" - -[[deps.RecursiveArrayTools]] -deps = ["Adapt", "ArrayInterface", "DocStringExtensions", "GPUArraysCore", "IteratorInterfaceExtensions", "LinearAlgebra", "RecipesBase", "SparseArrays", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface", "Tables"] -git-tree-sha1 = "d8f131090f2e44b145084928856a561c83f43b27" -uuid = "731186ca-8d62-57ce-b412-fbd966d074cd" -version = "3.13.0" - - [deps.RecursiveArrayTools.extensions] - RecursiveArrayToolsFastBroadcastExt = "FastBroadcast" - RecursiveArrayToolsForwardDiffExt = "ForwardDiff" - RecursiveArrayToolsMeasurementsExt = "Measurements" - RecursiveArrayToolsMonteCarloMeasurementsExt = "MonteCarloMeasurements" - RecursiveArrayToolsReverseDiffExt = ["ReverseDiff", "Zygote"] - RecursiveArrayToolsTrackerExt = "Tracker" - RecursiveArrayToolsZygoteExt = "Zygote" - - [deps.RecursiveArrayTools.weakdeps] - FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898" - ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" - Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7" - MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca" - ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" - Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" - Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" - -[[deps.Reexport]] -git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b" -uuid = "189a3867-3050-52da-a836-e630ba90ab69" -version = "1.2.2" - -[[deps.Requires]] -deps = ["UUIDs"] -git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7" -uuid = "ae029012-a4dd-5104-9daa-d747884805df" -version = "1.3.0" - -[[deps.RuntimeGeneratedFunctions]] -deps = ["ExprTools", "SHA", "Serialization"] -git-tree-sha1 = "04c968137612c4a5629fa531334bb81ad5680f00" -uuid = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47" -version = "0.5.13" - -[[deps.SHA]] -uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" -version = "0.7.0" - -[[deps.SIMDTypes]] -git-tree-sha1 = "330289636fb8107c5f32088d2741e9fd7a061a5c" -uuid = "94e857df-77ce-4151-89e5-788b33177be4" -version = "0.1.0" - -[[deps.SciMLBase]] -deps = ["ADTypes", "ArrayInterface", "CommonSolve", "ConstructionBase", "Distributed", "DocStringExtensions", "EnumX", "FunctionWrappersWrappers", "IteratorInterfaceExtensions", "LinearAlgebra", "Logging", "Markdown", "PrecompileTools", "Preferences", "Printf", "RecipesBase", "RecursiveArrayTools", "Reexport", "RuntimeGeneratedFunctions", "SciMLOperators", "SciMLStructures", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface", "Tables"] -git-tree-sha1 = "beb1f94b08c4976ed1db0ca01b9e6bac89706faf" -uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462" -version = "2.35.0" - - [deps.SciMLBase.extensions] - SciMLBaseChainRulesCoreExt = "ChainRulesCore" - SciMLBaseMakieExt = "Makie" - SciMLBasePartialFunctionsExt = "PartialFunctions" - SciMLBasePyCallExt = "PyCall" - SciMLBasePythonCallExt = "PythonCall" - SciMLBaseRCallExt = "RCall" - SciMLBaseZygoteExt = "Zygote" - - [deps.SciMLBase.weakdeps] - ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" - ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a" - PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b" - PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" - PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d" - RCall = "6f49c342-dc21-5d91-9882-a32aef131414" - Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" - -[[deps.SciMLOperators]] -deps = ["ArrayInterface", "DocStringExtensions", "LinearAlgebra", "MacroTools", "Setfield", "SparseArrays", "StaticArraysCore"] -git-tree-sha1 = "10499f619ef6e890f3f4a38914481cc868689cd5" -uuid = "c0aeaf25-5076-4817-a8d5-81caf7dfa961" -version = "0.3.8" - -[[deps.SciMLStructures]] -git-tree-sha1 = "5833c10ce83d690c124beedfe5f621b50b02ba4d" -uuid = "53ae85a6-f571-4167-b2af-e1d143709226" -version = "1.1.0" - -[[deps.Serialization]] -uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" - -[[deps.Setfield]] -deps = ["ConstructionBase", "Future", "MacroTools", "StaticArraysCore"] -git-tree-sha1 = "e2cc6d8c88613c05e1defb55170bf5ff211fbeac" -uuid = "efcf1570-3423-57d1-acb7-fd33fddbac46" -version = "1.1.1" - -[[deps.Sockets]] -uuid = "6462fe0b-24de-5631-8697-dd941f90decc" - -[[deps.SparseArrays]] -deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] -uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" -version = "1.10.0" - -[[deps.SpecialFunctions]] -deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] -git-tree-sha1 = "e2cfc4012a19088254b3950b85c3c1d8882d864d" -uuid = "276daf66-3868-5448-9aa4-cd146d93841b" -version = "2.3.1" -weakdeps = ["ChainRulesCore"] - - [deps.SpecialFunctions.extensions] - SpecialFunctionsChainRulesCoreExt = "ChainRulesCore" - -[[deps.Static]] -deps = ["IfElse"] -git-tree-sha1 = "d2fdac9ff3906e27f7a618d47b676941baa6c80c" -uuid = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" -version = "0.8.10" - -[[deps.StaticArrayInterface]] -deps = ["ArrayInterface", "Compat", "IfElse", "LinearAlgebra", "PrecompileTools", "Requires", "SparseArrays", "Static", "SuiteSparse"] -git-tree-sha1 = "5d66818a39bb04bf328e92bc933ec5b4ee88e436" -uuid = "0d7ed370-da01-4f52-bd93-41d350b8b718" -version = "1.5.0" - - [deps.StaticArrayInterface.extensions] - StaticArrayInterfaceOffsetArraysExt = "OffsetArrays" - StaticArrayInterfaceStaticArraysExt = "StaticArrays" - - [deps.StaticArrayInterface.weakdeps] - OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" - StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" - -[[deps.StaticArrays]] -deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"] -git-tree-sha1 = "bf074c045d3d5ffd956fa0a461da38a44685d6b2" -uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.9.3" -weakdeps = ["ChainRulesCore", "Statistics"] - - [deps.StaticArrays.extensions] - StaticArraysChainRulesCoreExt = "ChainRulesCore" - StaticArraysStatisticsExt = "Statistics" - -[[deps.StaticArraysCore]] -git-tree-sha1 = "36b3d696ce6366023a0ea192b4cd442268995a0d" -uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" -version = "1.4.2" - -[[deps.Statistics]] -deps = ["LinearAlgebra", "SparseArrays"] -uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -version = "1.10.0" - -[[deps.StatsAPI]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "1ff449ad350c9c4cbc756624d6f8a8c3ef56d3ed" -uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" -version = "1.7.0" - -[[deps.SteadyStateDiffEq]] -deps = ["ConcreteStructs", "DiffEqBase", "DiffEqCallbacks", "LinearAlgebra", "Reexport", "SciMLBase"] -git-tree-sha1 = "a735fd5053724cf4de31c81b4e2cc429db844be5" -uuid = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f" -version = "2.0.1" - -[[deps.StrideArraysCore]] -deps = ["ArrayInterface", "CloseOpenIntervals", "IfElse", "LayoutPointers", "LinearAlgebra", "ManualMemory", "SIMDTypes", "Static", "StaticArrayInterface", "ThreadingUtilities"] -git-tree-sha1 = "25349bf8f63aa36acbff5e3550a86e9f5b0ef682" -uuid = "7792a7ef-975c-4747-a70f-980b88e8d1da" -version = "0.5.6" - -[[deps.Strided]] -deps = ["LinearAlgebra", "StridedViews", "TupleTools"] -git-tree-sha1 = "40c69be0e1b72ee2f42923b7d1ff13e0b04e675c" -uuid = "5e0ebb24-38b0-5f93-81fe-25c709ecae67" -version = "2.0.4" - -[[deps.StridedViews]] -deps = ["LinearAlgebra", "PackageExtensionCompat"] -git-tree-sha1 = "5b765c4e401693ab08981989f74a36a010aa1d8e" -uuid = "4db3bf67-4bd7-4b4e-b153-31dc3fb37143" -version = "0.2.2" - - [deps.StridedViews.extensions] - StridedViewsCUDAExt = "CUDA" - - [deps.StridedViews.weakdeps] - CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" - -[[deps.SuiteSparse]] -deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] -uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" - -[[deps.SuiteSparse_jll]] -deps = ["Artifacts", "Libdl", "libblastrampoline_jll"] -uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" -version = "7.2.1+1" - -[[deps.SymbolicIndexingInterface]] -deps = ["Accessors", "ArrayInterface", "MacroTools", "RuntimeGeneratedFunctions", "StaticArraysCore"] -git-tree-sha1 = "40ea524431a92328cd73582d1820a5b08247a40f" -uuid = "2efcf032-c050-4f8e-a9bb-153293bab1f5" -version = "0.3.16" - -[[deps.TOML]] -deps = ["Dates"] -uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" -version = "1.0.3" - -[[deps.TableTraits]] -deps = ["IteratorInterfaceExtensions"] -git-tree-sha1 = "c06b2f539df1c6efa794486abfb6ed2022561a39" -uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c" -version = "1.0.1" - -[[deps.Tables]] -deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "LinearAlgebra", "OrderedCollections", "TableTraits"] -git-tree-sha1 = "cb76cf677714c095e535e3501ac7954732aeea2d" -uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" -version = "1.11.1" - -[[deps.Tar]] -deps = ["ArgTools", "SHA"] -uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" -version = "1.10.0" - -[[deps.Test]] -deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] -uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[[deps.ThreadingUtilities]] -deps = ["ManualMemory"] -git-tree-sha1 = "eda08f7e9818eb53661b3deb74e3159460dfbc27" -uuid = "8290d209-cae3-49c0-8002-c8c24d57dab5" -version = "0.5.2" - -[[deps.Tricks]] -git-tree-sha1 = "eae1bb484cd63b36999ee58be2de6c178105112f" -uuid = "410a4b4d-49e4-4fbc-ab6d-cb71b17b3775" -version = "0.1.8" - -[[deps.TruncatedStacktraces]] -deps = ["InteractiveUtils", "MacroTools", "Preferences"] -git-tree-sha1 = "ea3e54c2bdde39062abf5a9758a23735558705e1" -uuid = "781d530d-4396-4725-bb49-402e4bee1e77" -version = "1.4.0" - -[[deps.TupleTools]] -git-tree-sha1 = "41d61b1c545b06279871ef1a4b5fcb2cac2191cd" -uuid = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6" -version = "1.5.0" - -[[deps.UUIDs]] -deps = ["Random", "SHA"] -uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" - -[[deps.UnPack]] -git-tree-sha1 = "387c1f73762231e86e0c9c5443ce3b4a0a9a0c2b" -uuid = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" -version = "1.0.2" - -[[deps.Unicode]] -uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" - -[[deps.UnsafeAtomics]] -git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278" -uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f" -version = "0.2.1" - -[[deps.UnsafeAtomicsLLVM]] -deps = ["LLVM", "UnsafeAtomics"] -git-tree-sha1 = "323e3d0acf5e78a56dfae7bd8928c989b4f3083e" -uuid = "d80eeb9a-aca5-4d75-85e5-170c8b632249" -version = "0.1.3" - -[[deps.WeightInitializers]] -deps = ["ChainRulesCore", "LinearAlgebra", "PartialFunctions", "PrecompileTools", "Random", "SpecialFunctions", "Statistics"] -git-tree-sha1 = "f0e6760ef9d22f043710289ddf29e4a4048c4822" -uuid = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" -version = "0.1.7" - - [deps.WeightInitializers.extensions] - WeightInitializersCUDAExt = "CUDA" - - [deps.WeightInitializers.weakdeps] - CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" - -[[deps.Zlib_jll]] -deps = ["Libdl"] -uuid = "83775a58-1f1d-513f-b197-d71354ab007a" -version = "1.2.13+1" - -[[deps.libblastrampoline_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" -version = "5.8.0+1" - -[[deps.nghttp2_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" -version = "1.52.0+1" - -[[deps.p7zip_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" -version = "17.4.0+2" diff --git a/Project.toml b/Project.toml index 37d5d09d..0f57e746 100644 --- a/Project.toml +++ b/Project.toml @@ -13,7 +13,6 @@ DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" -LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" @@ -43,7 +42,7 @@ FastClosures = "0.3" ForwardDiff = "0.10.36" Functors = "0.4.10" LinearSolve = "2.21.2" -Lux = "0.5.37" +Lux = "0.5.38" LuxCUDA = "0.3.2" LuxCore = "0.1.14" LuxTestUtils = "0.1.15"