From 2a20c5a39b973334e797203ac0318dd32b23b8c3 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 9 Apr 2024 12:29:51 -0400 Subject: [PATCH] Revert "Test if ChainRules problem is resolved" This reverts commit 1daeb5956d026dd4b1a5790ff35385d39a6a5860. --- Project.toml | 6 +++++- ext/LuxChainRulesExt.jl | 18 ++++++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) create mode 100644 ext/LuxChainRulesExt.jl diff --git a/Project.toml b/Project.toml index 4e6be37a88..b6da5a9a69 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Lux" uuid = "b2108857-7c20-44ae-9111-449ecde12c47" authors = ["Avik Pal and contributors"] -version = "0.5.34" +version = "0.5.33" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -29,6 +29,7 @@ WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" [weakdeps] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b" @@ -42,6 +43,7 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] +LuxChainRulesExt = "ChainRules" LuxComponentArraysExt = "ComponentArrays" LuxComponentArraysReverseDiffExt = ["ComponentArrays", "ReverseDiff"] LuxFluxExt = "Flux" @@ -61,6 +63,7 @@ Adapt = "4" Aqua = "0.8.4" ArrayInterface = "7.8" CUDA = "5.2" +ChainRules = "1.62" ChainRulesCore = "1.21" ComponentArrays = "0.15.11" ConcreteStructs = "0.2.3" @@ -105,6 +108,7 @@ julia = "1.10" ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" diff --git a/ext/LuxChainRulesExt.jl b/ext/LuxChainRulesExt.jl new file mode 100644 index 0000000000..9ecf874d01 --- /dev/null +++ b/ext/LuxChainRulesExt.jl @@ -0,0 +1,18 @@ +module LuxChainRulesExt + +using ChainRules: ChainRules + +# https://github.com/FluxML/Zygote.jl/pull/1328 broke the RNNs completely. Putting an +# emergency patch here +function ChainRules._setindex_zero( + x::Vector{<:AbstractArray{T}}, dy, inds::Integer...) where {T <: Number} + return [fill!(similar(xᵢ), 0) for xᵢ in x] +end + +function ChainRules.∇getindex!( + dx::Vector{<:AbstractArray{T}}, dy, inds::Integer...) where {T <: Number} + dx[inds...] .+= dy + return dx +end + +end