Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Housekeeping + Use Faster Nested AD #152

Merged
merged 6 commits into from
Apr 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .JuliaFormatter.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@ whitespace_in_kwargs = false
format_docstrings = true
separate_kwargs_with_semicolon = true
format_markdown = true
annotate_untyped_fields_with_any = false
annotate_untyped_fields_with_any = false
join_lines_based_on_source = false
2 changes: 2 additions & 0 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,5 +54,7 @@ steps:
timeout_in_minutes: 240

env:
RETESTITEMS_NWORKERS: 4
RETESTITEMS_NWORKER_THREADS: 2
SECRET_CODECOV_TOKEN: "fbSN+ZbScLIWr1FOpAu1Z8PYWFobqbLGFayOgZE1ebhE8LIH/PILGXUMcdm9gkXVSwgdETDD0s33k14lBkJ90O4dV9w6k79F/pEgzVHV8baMoXZG03BPMxztlcoRXrKtRtAp+MwoATc3Ldb9H5vqgAnVNn5rhn4Rp0Z6LOVRC43hbhKBBKYh/N4gqpIQlcW4dBXmELhlnMFnUILjwGRVgEt/zh8H+vmf0qiIulNIQ/rfGISROHqFML0QDL4icloiqX08J76ZP/gZCeg6rJ0gl3ok3IspNPz51rlbvijqsPNyIHWi29OrAtWX3qKHfrAOoGIrE1d5Oy4wx4XaN/YBhg==;U2FsdGVkX188gcRjkUNMEC2Z5fEFfhsYY4WJbhhINOuCUgqq9XNHVDbJhzFUFVQ+UiuPHFg7CW/gn+3IkSVyOA=="
SECRET_DOCUMENTER_KEY: "jzyAET5IdazYwPAEZAmYmnBALb2dC1GPizCDCdt8xpjIi4ce6QbGGJMKo00ZNzJ/A7ii4bhqysVPXniifFwIGl7x+GSCeavwcSr15pfxJSqPuQYLKxESzIo+SM+l2uJWUz8KYMJ1tSt/Z3Up3qQfLeQFtR+f43b9QrLfhgZGAAdxpwu5VHdI3Xm/gZo5d8xEJ1xs4gqVP0e2A5EFr/j/exaWJL9+AvgO+Gko8NaJGG5B89zP1W2NBlpjttbwzj2naBhDx8A43Qe4eXm+BZd9CIZImiEJnnqoGxLkAyLDksbA68getUHW5z3nGyhWTrg5yfRqq0uyZZGTIOFz6dJrRg==;U2FsdGVkX19QOxLLkdNoQf7Rid3mcSR/renIHQ+/X3o0WxTmU8KDDxzfKuWPeK1fxMon8y45HCJv3HlMuzyfvPWrOmUXccfHK272D8vHu1kk/qZZw8nPd7iYBU9+VAIxwfmI3Av2gC+8tUlOcuUTEVMtMbi/MiLHp+phLYcELKzzrxL8VdrLzna81M+8xVLu7zzNuyK0cUPWLxRHcZc/fewK5Nh7EQ2x8u1b6e5zR0/AcqjCzMayD1RiE7QhRVGdF5GJYnAxc1eoyCwIjXTRfFo0a0Q2h6DEz9FEat/ZCekIuWyVrUkGbpsRqXUTrSH0An7FRRqRlZ9lStRaQY4Z3XBkoIh94vQlXwwLUH20jC7yRTV73CeYmhfigQckHL0JsjjIENz04Ac346fCV6WNQtEak0m3pN/BucoiwRA8l+WU4AK1r84cwGSphKk4SnWRAqeZVuFHck7NkcmHDEkO4C7WTP400oui/5NDMtVZbtnZfLxVzQqijxXj7IflWqF1vKqGmW5aPFMVNeAqwNGu3xM4oIIeHRu0u+k2S5dp1wqRVlMxYXdPtcoFzE0CNsMQdWgsvPd2eet38YRc8ftXNjKzoUSRRCbjGbVr0iJXeNmPg3jfZoVdILHjCN/hcz4nY+61P11OlJAdfE/6HzEr4VoOS4CN+s/brjWycmAKZo2+1e4fSV1xBH7t1spOlESLvsBhZNtj9/zUKgWgMct5hnF4anQcPAeRpz/MBrkwX1gW3WOvCxaqVlRfgGSy6boPgRd3p/ZXN4Xnfeg9RFqKZn21d2gcrc3/1+PTUEkOIv+C9BGszo9IaUziW/Tz2mVP386kX86SF4fF4y3PofcUT2FLTm8Q9ZJBnslOsRP8bq3rIjDiQR3Iz3uGctkGZPs+GOtCR5OrhnnS6BXxkGwt/n9PJsnbXt0Z4tuXihC1B8KfP7mzDvZr3q9X/DGKyZ+oMHdDI+f2+lRwx42nJnsu+nZW9lyhdIwWla9F1rIoVz59HbUrmUhsVmFQYfjy7Nl18g8Wh5r9CkFL/vr6Zpy5lj1J/vhe1501X2FIkKOnLAM73GwtAa4GkbHyu5rNcij6YoozPrJWT4KRNFWGVAqNZ1atG8WwmziwIl2KfBn8jiuP/8o6rXQkmrAzBr6jVnto5FTWnIexEmnbELs20XDck8pO5WQxU1IR9YhKMbrDGbn0jWzVoRmCWpaJgV1AkWu09a++DxIec4+Zt+3SZLj/H57XsBchWHmkFz4NVTBeSans26VmdDd3LxprT8qeH6cioceakmu6yegsKQnJGLmSNyUkHqBqmsCcvyTUyaQUBTFkjLmDeZB3Ifu2kD7AFdx5n58wdJTMZxYviybOCgCV4qe95v5XfIqthp5mF/0F1Wt9ZcEreFSM2Paj5GrQ+M25cZ+kqOSlMet51Q+QBCfQyDF8jdu3j1hVniwpgMI1gqyb2alRfyNx52elTqRn9hPqpFptGH4uJXi8H72YPe4fYkFS7wwELeRIv+nKkNYNLPQAyQFvZ/qB/PRI1YoFBbpi0Vi6iE9xLRq7QVvhJde2EgNbvQk8uakwV630Tht2OuwVdJu/PIbXsQ5i+EuknIlPRdQdhbEIkpuBHFAzxBqA2K92gJ4bbcOjGtDHc0pt0RtvIVoyyJMkYVzr1yBeMWEmsL8qYJ5yzuAFGqpCTmJzXE0ETZLDDJtxwSKj5M2vG59wPNfo9DF+LgJLF+94VydYGNOHy9KuY2Oo3ejV7iFXUtsEV3Id9EkNGq8+t5KIAGk6lnDcM1TTOAc5W8fGGNhYzlqgWK1n3nwPJLykqY7VFHPZjF8Il/8E1IubnPCIyOCTJwKqQlBB5td/bt7YIDEFmkpl7OvUwyc2uYkFmrxGv81OtopsYZOJ+WnwSkqqZ3p2MyqNj3xp92p8itz5tM3tzjrkdfPXsx1QJGY+rkZhCsSf6DSG18AFqI4+Q8uWUwqO5/TJb2z/F2LT88+wJfGPtwGeR+98XgvwjsMWIA/TZfwTrTQsZX2YOIf0bg3yjlEbFM16xAFAA2oItBuvbC6d6NIit4Dukn2WamnOceoTyO6mdHYRh5SBOryr3AWnBJZsUPL3HsC+Xiibgixuwjjalj+HOrAzDlQc8L0Z77dZJhpST0x/gwCleSA3lOKs7MA8ASolCaPVL2pPJXkb97mBxZx8k1n6abhK1w3QVJuYvp7CyGhavsYEqcR+vYx/T0tN4MVOjfRhimqhNihz0VDfY97YS5XavZV07jycqoAlufmH5VSwNbiy8/NY6Q2djc46ISbqvKr6Pf0TZBuJti2gIpm02Btm4rMCawVPpEPieU3GI77nhQp6orq0Zjl5f4XfOKnfcxiqdgip4SVFTglHBTWTiRFnGTF0I3VX4V+RRmqJwwKPN8cxDsNd6wSpylhDUAMfxEvvb+0vAt1yGNUC52OB4bSOXOyZIAU8+08xl7mYGIVUnoWHaR6Y0aHdnywJUuzQ2q3dotfnI1j72MzlHsTK6Lro3YiolDNJpTqLtxmSzkWctw/PfijnoEXtmDnZKptZ7t0v7oTAkdE3kk0RrnFTnMAkyCOREcFcyxglROCoDHsZx3Q+MkWLG/tPMVpuMRhy9gJ1WZTpeExNgs5KgwtrS1HJg7KunWXguFH/zDgODTdKclgfvsVe/SCtlpbO6z5fZji0j1y9LRBVLyTN/LzeR40OBX3r0abk4SGyslAdZMgg2WJdSLVAJ+MtxxbnlKDXDNmu5YehpWdTvm/wIYwTKw+1A48plKburw8fBEofVy9Ubmc8E4z6hQRX2cwcNN9N/60aCwlpM7wVbYfBo4Hw9H/6EawjbRRN9UwmgsfmYUuqCTSi8fNi2dR36bqaoHHURgyqW7DiR7BYgVnOZ+B/2GM8uO2rYgSOhVJf+OK+2HsNly0MW5v3/ft6W7PEsab8IweYWmPLVvJNfHW5CDP6KotdDgm/DcD5owgQ12D95BGWawR5gQxpyjX9uIlxORPq6h0Z79j8gFFsYIfddIdxsJZS9r59FtZe2JL7nK3Dum07tXDGlBCUD4mwv+LNxOJLa8DM6YoEd4Nh8qosfQNJu505Vh/r+PgegnFvG9LRkwQnk8fgPTNKThB067s82YuVg0mv4O9q4Hlm13wTWuvlMr4k1ShBrNyy08YaFCu2hmZm7RizU1rU5MpieiwpQ6cGx+sHBTszB+c89045n2TM4VUedi+vjEq2KuhmIl9ID0EHvWDy5iwOQV5nJ7Rk/Alky2GRZ1CpnJPN29q5lrs6fhvfPquolJTyBTNgVjQ7f0z1zuUQFdhWQX5BFyq/bT46qc+X6dSOvlFenioPDe5MYpA3SZCi2lmVQqHnTOcRZp2HtYpoRUzNB6cT7P1wkRTvAR5PQyuAknkKB+T6HvRb2H8EBLfk+imlyu7mb2iAJNORpZ0Rz+no/5A3wd6qHfTq27h/CDZ91YqGamylrLBdFqyefFYfSbFa1BKikiavpZnYh19hHNl9v0Q4Tkb7ogy7biw8icOvMPWCrxd50zoY1EUe2maNKtkyytJeEOV+Yj6VaUa88M+7WSKaK8QbEB+fBKmwvGkknRKs8lYRoABChwqDZ7M+98pL057QbquxseKX6alzV+IOHlO7I8csRHnF6OpVQG6wdzJZhEwg/0n1K2qTExF9Z3WzUoiQ+NVC3gRQ9Au+x3fpkuLu85lOVjelX3JtdVt1T3623sqxmcEr/TCZT/+X1QyflAkJyw1EMq4sat7wjYK3ugyPPPjo/v2h4TuaoWc0X/+qJPV/o2Vu489loIQ+N59ABZOLldpbkPM7VJIOnnfg+/GMvaEor2YCYElDGXx4BdRmSfOyzFF2Wqz5iTxMbdCo3iZbPQqbFTdMOX7Hy3nT8vUOhCLo+Dkgb7B01nPnm7crmC+TOgi4iDLp4nCqx5OSiG5gd/m54gZHe6Cymwj+DbW303KcvpGBrk0xr2sGUkQiu4vxNz+uW09EyMNCp5cg2AUWG4w6ykTHXUeDDQG232C5K7/tTt8Z09Kp9v71PkwH7hmZUrjAERGvF29zucdTVRmzr++JMH82Sk0chAi5UFs/lbVcN/birI7OVl6okyO3+bKWMCuhje1huOgeZzSk5xKFrgJ1v69TyD1mOa5wYx6IskbWSrFW/sqrhURqpSlfdWVCZiaOHLb/UIgQ0s1xlHyZ0/YOBQFz1VCgKH217ALijV3FOr+q00761SKNFc/IhZLNtVhHhE7lutAjVqyme7RHKd4fjFOD6oREyDYXHULmDGPRTmlFSxwE4+n3N9AInajQLH82CGWO1nV3u7qSY5vSbuzQIxCr8OKQfW8AzTdNjUoEtU+ojprLZ4V4r1dr01eLqXSVJ12Yq9Sm/Ivu1SZkHJl9oIxNjWSbRRMoYIVH3yVv1HyGGajcmKuzIfORuBZm"
2 changes: 2 additions & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ jobs:
env:
GROUP: "CPU"
JULIA_NUM_THREADS: 12
RETESTITEMS_NWORKERS: 4
RETESTITEMS_NWORKER_THREADS: 2
- uses: julia-actions/julia-processcoverage@v1
with:
directories: src,ext
Expand Down
File renamed without changes.
57 changes: 44 additions & 13 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,49 +1,80 @@
name = "DeepEquilibriumNetworks"
uuid = "6748aba7-0e9b-415e-a410-ae3cc0ecb334"
authors = ["Avik Pal <[email protected]>"]
version = "2.0.3"
version = "2.1.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f"
TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"

[weakdeps]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
DeepEquilibriumNetworksLinearSolveSciMLSensitivityExt = ["LinearSolve", "SciMLSensitivity"]
DeepEquilibriumNetworksZygoteExt = "Zygote"
DeepEquilibriumNetworksSciMLSensitivityExt = ["LinearSolve", "SciMLSensitivity"]
DeepEquilibriumNetworksZygoteExt = ["ForwardDiff", "Zygote"]

[compat]
ADTypes = "0.2.5"
ADTypes = "0.2.5, 1"
Aqua = "0.8.7"
ChainRulesCore = "1"
CommonSolve = "0.2.4"
ConcreteStructs = "0.2"
ConstructionBase = "1"
DiffEqBase = "6.119"
ExplicitImports = "1.4.1"
FastClosures = "0.3"
LinearAlgebra = "1"
ForwardDiff = "0.10.36"
Functors = "0.4.10"
LinearSolve = "2.21.2"
Lux = "0.5.11"
Lux = "0.5.38"
LuxCUDA = "0.3.2"
LuxCore = "0.1.14"
LuxTestUtils = "0.1.15"
NLsolve = "4.5.1"
NonlinearSolve = "3.10.0"
OrdinaryDiffEq = "6.74.1"
PrecompileTools = "1"
Random = "1"
Random = "1.10"
ReTestItems = "1.23.1"
SciMLBase = "2"
SciMLSensitivity = "7.43"
Statistics = "1"
StableRNGs = "1.0.2"
Statistics = "1.10"
SteadyStateDiffEq = "2"
TruncatedStacktraces = "1.1"
Zygote = "0.6.67"
julia = "1.9"
Test = "1.10"
Zygote = "0.6.69"
julia = "1.10"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531"
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "ExplicitImports", "Functors", "LuxCUDA", "LuxTestUtils", "NLsolve", "NonlinearSolve", "OrdinaryDiffEq", "ReTestItems", "SciMLSensitivity", "StableRNGs", "Test", "Zygote"]
3 changes: 1 addition & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,7 @@ Random.seed!(rng, seed)

model = Chain(Dense(2 => 2),
DeepEquilibriumNetwork(
Parallel(+, Dense(2 => 2; use_bias=false),
Dense(2 => 2; use_bias=false)),
Parallel(+, Dense(2 => 2; use_bias=false), Dense(2 => 2; use_bias=false)),
NewtonRaphson()))

gdev = gpu_device()
Expand Down
4 changes: 2 additions & 2 deletions docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
[deps]
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
DeepEquilibriumNetworks = "6748aba7-0e9b-415e-a410-ae3cc0ecb334"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
MLDataUtils = "cc2ba9b6-d476-5e6d-8eaf-a92d5412d41d"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Expand All @@ -21,7 +22,6 @@ DeepEquilibriumNetworks = "2"
Documenter = "1"
DocumenterCitations = "1"
LinearSolve = "2"
LoggingExtras = "1"
Lux = "0.5"
LuxCUDA = "0.3"
MLDataUtils = "0.5"
Expand Down
11 changes: 8 additions & 3 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,15 @@ bib = CitationBibliography(joinpath(@__DIR__, "ref.bib"); style=:authoryear)

include("pages.jl")

makedocs(; sitename="Deep Equilibrium Networks", authors="Avik Pal et al.",
modules=[DeepEquilibriumNetworks], clean=true, doctest=true, linkcheck=true,
makedocs(; sitename="Deep Equilibrium Networks",
authors="Avik Pal et al.",
modules=[DeepEquilibriumNetworks],
clean=true,
doctest=true,
linkcheck=true,
format=Documenter.HTML(; assets=["assets/favicon.ico"],
canonical="https://docs.sciml.ai/DeepEquilibriumNetworks/stable/"),
plugins=[bib], pages)
plugins=[bib],
pages)

deploydocs(; repo="github.com/SciML/DeepEquilibriumNetworks.jl.git", push_preview=true)
12 changes: 3 additions & 9 deletions docs/pages.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,3 @@
pages = [
"Home" => "index.md",
"Tutorials" => [
"tutorials/basic_mnist_deq.md",
"tutorials/reduced_dim_deq.md"
],
"API References" => "api.md",
"References" => "references.md"
]
pages = ["Home" => "index.md",
"Tutorials" => ["tutorials/basic_mnist_deq.md", "tutorials/reduced_dim_deq.md"],
"API References" => "api.md", "References" => "references.md"]
3 changes: 1 addition & 2 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@ Random.seed!(rng, seed)

model = Chain(Dense(2 => 2),
DeepEquilibriumNetwork(
Parallel(+, Dense(2 => 2; use_bias=false),
Dense(2 => 2; use_bias=false)),
Parallel(+, Dense(2 => 2; use_bias=false), Dense(2 => 2; use_bias=false)),
NewtonRaphson()))

gdev = gpu_device()
Expand Down
76 changes: 28 additions & 48 deletions docs/src/tutorials/basic_mnist_deq.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ We will train a simple Deep Equilibrium Model on MNIST. First we load a few pack

```@example basic_mnist_deq
using DeepEquilibriumNetworks, SciMLSensitivity, Lux, NonlinearSolve, OrdinaryDiffEq,
Statistics, Random, Optimisers, LuxCUDA, Zygote, LinearSolve, LoggingExtras
Statistics, Random, Optimisers, LuxCUDA, Zygote, LinearSolve, Dates, Printf
using MLDatasets: MNIST
using MLDataUtils: LabelEnc, convertlabel, stratifiedobs, batchview

Expand All @@ -20,18 +20,6 @@ const cdev = cpu_device()
const gdev = gpu_device()
```

SciMLBase introduced a warning instead of depwarn which pollutes the output. We can suppress
it with the following logger

```@example basic_mnist_deq
function remove_syms_warning(log_args)
return log_args.message !=
"The use of keyword arguments `syms`, `paramsyms` and `indepsym` for `SciMLFunction`s is deprecated. Pass `sys = SymbolCache(syms, paramsyms, indepsym)` instead."
end

filtered_logger = ActiveFilteredLogger(remove_syms_warning, global_logger())
```

We can now construct our dataloader.

```@example basic_mnist_deq
Expand Down Expand Up @@ -66,8 +54,7 @@ function construct_model(solver; model_type::Symbol=:deq)

# The input layer of the DEQ
deq_model = Chain(
Parallel(+,
Conv((3, 3), 64 => 64, tanh; stride=1, pad=SamePad()),
Parallel(+, Conv((3, 3), 64 => 64, tanh; stride=1, pad=SamePad()),
Conv((3, 3), 64 => 64, tanh; stride=1, pad=SamePad())),
Conv((3, 3), 64 => 64, tanh; stride=1, pad=SamePad()))

Expand All @@ -79,11 +66,11 @@ function construct_model(solver; model_type::Symbol=:deq)
init = missing
end

deq = DeepEquilibriumNetwork(deq_model, solver; init, verbose=false,
linsolve_kwargs=(; maxiters=10))
deq = DeepEquilibriumNetwork(
deq_model, solver; init, verbose=false, linsolve_kwargs=(; maxiters=10))

classifier = Chain(GroupNorm(64, 64, relu), GlobalMeanPool(), FlattenLayer(),
Dense(64, 10))
classifier = Chain(
GroupNorm(64, 64, relu), GlobalMeanPool(), FlattenLayer(), Dense(64, 10))

model = Chain(; down, deq, classifier)

Expand All @@ -95,12 +82,12 @@ function construct_model(solver; model_type::Symbol=:deq)
x = randn(rng, Float32, 28, 28, 1, 128)
y = onehot(rand(Random.default_rng(), 0:9, 128)) |> gdev

model_ = Lux.Experimental.StatefulLuxLayer(model, ps, st)
@info "warming up forward pass"
model_ = StatefulLuxLayer(model, ps, st)
@printf "[%s] warming up forward pass\n" string(now())
logitcrossentropy(model_, x, ps, y)
@info "warming up backward pass"
@printf "[%s] warming up backward pass\n" string(now())
Zygote.gradient(logitcrossentropy, model_, x, ps, y)
@info "warmup complete"
@printf "[%s] warmup complete\n" string(now())

return model, ps, st
end
Expand All @@ -122,7 +109,7 @@ classify(x) = argmax.(eachcol(x))
function accuracy(model, data, ps, st)
total_correct, total = 0, 0
st = Lux.testmode(st)
model = Lux.Experimental.StatefulLuxLayer(model, ps, st)
model = StatefulLuxLayer(model, ps, st)
for (x, y) in data
target_class = classify(cdev(y))
predicted_class = classify(cdev(model(x)))
Expand All @@ -132,51 +119,48 @@ function accuracy(model, data, ps, st)
return total_correct / total
end

function train_model(solver, model_type; data_train=zip(x_train, y_train),
data_test=zip(x_test, y_test))
function train_model(
solver, model_type; data_train=zip(x_train, y_train), data_test=zip(x_test, y_test))
model, ps, st = construct_model(solver; model_type)
model_st = Lux.Experimental.StatefulLuxLayer(model, nothing, st)
model_st = StatefulLuxLayer(model, nothing, st)

@info "Training Model: $(model_type) with Solver: $(nameof(typeof(solver)))"
@printf "[%s] Training Model: %s with Solver: %s\n" string(now()) model_type nameof(typeof(solver))

opt_st = Optimisers.setup(Adam(0.001), ps)

acc = accuracy(model, data_test, ps, st) * 100
@info "Starting Accuracy: $(acc)"
@printf "[%s] Starting Accuracy: %.5f%%\n" string(now()) acc

@info "Pretrain with unrolling to a depth of 5"
@printf "[%s] Pretrain with unrolling to a depth of 5\n" string(now())
st = Lux.update_state(st, :fixed_depth, Val(5))
model_st = Lux.Experimental.StatefulLuxLayer(model, ps, st)
model_st = StatefulLuxLayer(model, ps, st)

for (i, (x, y)) in enumerate(data_train)
res = Zygote.withgradient(logitcrossentropy, model_st, x, ps, y)
Optimisers.update!(opt_st, ps, res.grad[3])
if i % 50 == 1
@info "Pretraining Batch: [$(i)/$(length(data_train))] Loss: $(res.val)"
end
i % 50 == 1 &&
@printf "[%s] Pretraining Batch: [%4d/%4d] Loss: %.5f\n" string(now()) i length(data_train) res.val
end

acc = accuracy(model, data_test, ps, model_st.st) * 100
@info "Pretraining complete. Accuracy: $(acc)"
@printf "[%s] Pretraining complete. Accuracy: %.5f%%\n" string(now()) acc

st = Lux.update_state(st, :fixed_depth, Val(0))
model_st = Lux.Experimental.StatefulLuxLayer(model, ps, st)
model_st = StatefulLuxLayer(model, ps, st)

for epoch in 1:3
for (i, (x, y)) in enumerate(data_train)
res = Zygote.withgradient(logitcrossentropy, model_st, x, ps, y)
Optimisers.update!(opt_st, ps, res.grad[3])
if i % 50 == 1
@info "Epoch: [$(epoch)/3] Batch: [$(i)/$(length(data_train))] Loss: $(res.val)"
end
i % 50 == 1 &&
@printf "[%s] Epoch: [%d/%d] Batch: [%4d/%4d] Loss: %.5f\n" string(now()) epoch 3 i length(data_train) res.val
end

acc = accuracy(model, data_test, ps, model_st.st) * 100
@info "Epoch: [$(epoch)/3] Accuracy: $(acc)"
@printf "[%s] Epoch: [%d/%d] Accuracy: %.5f%%\n" string(now()) epoch 3 acc
end

@info "Training complete."
println()
@printf "[%s] Training complete.\n" string(now())

return model, ps, st
end
Expand All @@ -188,19 +172,15 @@ and end up using solvers like `Broyden`, but we can simply slap in any of the fa
from NonlinearSolve.jl. Here we will use Newton-Krylov Method:

```@example basic_mnist_deq
with_logger(filtered_logger) do
train_model(NewtonRaphson(; linsolve=KrylovJL_GMRES()), :regdeq)
end
train_model(NewtonRaphson(; linsolve=KrylovJL_GMRES()), :regdeq);
nothing # hide
```

We can also train a continuous DEQ by passing in an ODE solver. Here we will use `VCAB3()`
which tend to be quite fast for continuous Neural Network problems.

```@example basic_mnist_deq
with_logger(filtered_logger) do
train_model(VCAB3(), :deq)
end
train_model(VCAB3(), :deq);
nothing # hide
```

Expand Down
Loading
Loading