diff --git a/Manifest.toml b/Manifest.toml index 142f0567..c4d00d25 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -497,7 +497,7 @@ uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" [[deps.Lux]] deps = ["ADTypes", "Adapt", "ArrayInterface", "ChainRulesCore", "ConcreteStructs", "ConstructionBase", "FastClosures", "Functors", "GPUArraysCore", "LinearAlgebra", "LuxCore", "LuxDeviceUtils", "LuxLib", "MacroTools", "Markdown", "PrecompileTools", "Preferences", "Random", "Reexport", "Setfield", "Statistics", "WeightInitializers"] -git-tree-sha1 = "d7f49df9abfbb372fcbde5f41e547aa3679e9793" +git-tree-sha1 = "295c76513705518749fd4e151d9de77c75049d43" repo-rev = "ap/nested_ad" repo-url = "https://github.com/LuxDL/Lux.jl.git" uuid = "b2108857-7c20-44ae-9111-449ecde12c47" @@ -573,12 +573,13 @@ version = "0.1.20" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [[deps.LuxLib]] -deps = ["ChainRulesCore", "FastClosures", "KernelAbstractions", "LuxCore", "Markdown", "NNlib", "PrecompileTools", "Random", "Reexport", "Statistics"] -git-tree-sha1 = "b1f81a8aa8313c1f1b4cbfb18733db17c023427e" +deps = ["ArrayInterface", "ChainRulesCore", "FastBroadcast", "FastClosures", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "LuxCore", "Markdown", "NNlib", "PrecompileTools", "Random", "Reexport", "Statistics", "Strided"] +git-tree-sha1 = "7cb3cdf01835d508f2c81e09d2e93f309434b5d6" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" -version = "0.3.14" +version = "0.3.15" [deps.LuxLib.extensions] + LuxLibAMDGPUExt = "AMDGPU" LuxLibForwardDiffExt = "ForwardDiff" LuxLibReverseDiffExt = "ReverseDiff" LuxLibTrackerAMDGPUExt = ["AMDGPU", "Tracker"] @@ -684,6 +685,12 @@ git-tree-sha1 = "dfdf5519f235516220579f949664f1bf44e741c5" uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" version = "1.6.3" +[[deps.PackageExtensionCompat]] +git-tree-sha1 = "fb28e33b8a95c4cee25ce296c817d89cc2e53518" +uuid = "65ce6f38-6b18-4e1d-a461-8949797d7930" +version = "1.0.2" +weakdeps = ["Requires", "TOML"] + [[deps.Parameters]] 
deps = ["OrderedCollections", "UnPack"] git-tree-sha1 = "34c0e9ad262e5f7fc75b10a9952ca7692cfc5fbe" @@ -927,6 +934,24 @@ git-tree-sha1 = "25349bf8f63aa36acbff5e3550a86e9f5b0ef682" uuid = "7792a7ef-975c-4747-a70f-980b88e8d1da" version = "0.5.6" +[[deps.Strided]] +deps = ["LinearAlgebra", "StridedViews", "TupleTools"] +git-tree-sha1 = "40c69be0e1b72ee2f42923b7d1ff13e0b04e675c" +uuid = "5e0ebb24-38b0-5f93-81fe-25c709ecae67" +version = "2.0.4" + +[[deps.StridedViews]] +deps = ["LinearAlgebra", "PackageExtensionCompat"] +git-tree-sha1 = "5b765c4e401693ab08981989f74a36a010aa1d8e" +uuid = "4db3bf67-4bd7-4b4e-b153-31dc3fb37143" +version = "0.2.2" + + [deps.StridedViews.extensions] + StridedViewsCUDAExt = "CUDA" + + [deps.StridedViews.weakdeps] + CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + [[deps.SuiteSparse]] deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" @@ -985,6 +1010,11 @@ git-tree-sha1 = "ea3e54c2bdde39062abf5a9758a23735558705e1" uuid = "781d530d-4396-4725-bb49-402e4bee1e77" version = "1.4.0" +[[deps.TupleTools]] +git-tree-sha1 = "41d61b1c545b06279871ef1a4b5fcb2cac2191cd" +uuid = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6" +version = "1.5.0" + [[deps.UUIDs]] deps = ["Random", "SHA"] uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" diff --git a/Project.toml b/Project.toml index 0436322f..d3bab847 100644 --- a/Project.toml +++ b/Project.toml @@ -20,13 +20,14 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f" [weakdeps] +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] DeepEquilibriumNetworksSciMLSensitivityExt = ["LinearSolve", "SciMLSensitivity"] -DeepEquilibriumNetworksZygoteExt = "Zygote" +DeepEquilibriumNetworksZygoteExt = ["ForwardDiff", "Zygote"] [compat] ADTypes = 
"0.2.5, 1"
@@ -38,6 +39,7 @@ ConstructionBase = "1"
 DiffEqBase = "6.119"
 ExplicitImports = "1.4.1"
 FastClosures = "0.3"
+ForwardDiff = "0.10.36"
 Functors = "0.4.10"
 LinearSolve = "2.21.2"
 Lux = "0.5.37"
diff --git a/ext/DeepEquilibriumNetworksZygoteExt.jl b/ext/DeepEquilibriumNetworksZygoteExt.jl
index 7b848443..688bd2ca 100644
--- a/ext/DeepEquilibriumNetworksZygoteExt.jl
+++ b/ext/DeepEquilibriumNetworksZygoteExt.jl
@@ -1,19 +1,58 @@
 module DeepEquilibriumNetworksZygoteExt
 
 using ADTypes: AutoZygote
+using ChainRulesCore: ChainRulesCore
+using DeepEquilibriumNetworks: DEQs
 using FastClosures: @closure
+using ForwardDiff: ForwardDiff # This is a dependency of Zygote
+using Lux: Lux, StatefulLuxLayer
 using Statistics: mean
 using Zygote: Zygote
-using DeepEquilibriumNetworks: DEQs
 
-@inline __tupleify(u) = @closure x -> (u, x)
+const CRC = ChainRulesCore
+## `__tupleify(x)` returns the closure `u -> (u, x)`, capturing `x`.
+@inline __tupleify(x) = @closure(u->(u, x))
+
+## One day we will overload DI's APIs for Lux Layers and we can remove this
+## Main challenge with overloading Zygote.pullback is that we need to return the correct
+## tangent for the pullback to compute the correct gradient, which is quite hard. But
+## wrapping the overall vjp is not that hard.
+@inline function __compute_vector_jacobian_product(model::StatefulLuxLayer, ps, z, x, rng)
+    res, back = Zygote.pullback(model ∘ __tupleify(x), z)  # pullback w.r.t. `z` only
+    return only(back(DEQs.__gaussian_like(rng, res)))  # vjp seeded by `__gaussian_like`
+end
+## NOTE: the primal above never reads `ps`; the rule below returns a tangent for it.
+function CRC.rrule(  # reverse rule for `__compute_vector_jacobian_product`
+        ::typeof(__compute_vector_jacobian_product), model::StatefulLuxLayer, ps, z, x, rng)
+    res, back = Zygote.pullback(model ∘ __tupleify(x), z)  # same primal as above
+    ε = DEQs.__gaussian_like(rng, res)  # kept in a local so the pullback below reuses it
+    y = only(back(ε))
+    ∇internal_gradient_capture = Δ -> begin
+        (Δ isa CRC.NoTangent || Δ isa CRC.ZeroTangent) &&
+            return ntuple(Returns(CRC.NoTangent()), 6)  # no cotangent to propagate
+        # Give the unthunked cotangent the shape of `z`.
+        Δ_ = reshape(CRC.unthunk(Δ), size(z))
+        # Seed `z` with a single ForwardDiff partial per element, taken from `Δ_`.
+        Tag = typeof(ForwardDiff.Tag(model, eltype(z)))
+        partials = ForwardDiff.Partials{1, eltype(z)}.(tuple.(Δ_))
+        z_dual = ForwardDiff.Dual{Tag, eltype(z), 1}.(z, partials)
+        # Pull `ε` back through the model evaluated at the dual-valued `z`.
+        _, pb_f = Zygote.pullback((x1, x2, p) -> model((x1, x2), p), z_dual, x, ps)
+        ∂z_duals, ∂x_duals, ∂ps_duals = pb_f(ε)
+        # NOTE(review): presumably extracts partial 1 of each gradient; confirm in Lux.
+        ∂z = Lux.__partials(Tag, ∂z_duals, 1)
+        ∂x = Lux.__partials(Tag, ∂x_duals, 1)
+        ∂ps = Lux.__partials(Tag, ∂ps_duals, 1)
+        # Tangent order: the function itself, `model`, `ps`, `z`, `x`, `rng`.
+        return CRC.NoTangent(), CRC.NoTangent(), ∂ps, ∂z, ∂x, CRC.NoTangent()
+    end
+    return y, ∇internal_gradient_capture
+end
 
 ## Don't remove `ad`. See https://github.com/ericphanson/ExplicitImports.jl/issues/33
 ## FIXME: This will be broken in the new Lux release let's fix this
 function DEQs.__estimate_jacobian_trace(ad::AutoZygote, model, z, x, rng)
-    res, back = Zygote.pullback(model ∘ __tupleify, z)
-    vjp_z = only(back(DEQs.__gaussian_like(rng, res)))
-    return mean(abs2, vjp_z)
+    return mean(abs2, __compute_vector_jacobian_product(model, model.ps, z, x, rng))
 end
 
 end