From b2b36f60c2a5902c005bbaa0ed81697634090c15 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 21 Sep 2023 19:07:21 -0400 Subject: [PATCH 1/2] Upstream CA patches for AD Packages --- Project.toml | 6 ++--- ext/LuxComponentArraysTrackerExt.jl | 35 ----------------------------- ext/LuxComponentArraysZygoteExt.jl | 9 -------- 3 files changed, 2 insertions(+), 48 deletions(-) delete mode 100644 ext/LuxComponentArraysTrackerExt.jl delete mode 100644 ext/LuxComponentArraysZygoteExt.jl diff --git a/Project.toml b/Project.toml index 54e620e960..93bba56588 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Lux" uuid = "b2108857-7c20-44ae-9111-449ecde12c47" authors = ["Avik Pal and contributors"] -version = "0.5.6" +version = "0.5.7" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -38,8 +38,6 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] LuxComponentArraysExt = "ComponentArrays" LuxComponentArraysReverseDiffExt = ["ComponentArrays", "ReverseDiff"] -LuxComponentArraysTrackerExt = ["ComponentArrays", "Tracker"] -LuxComponentArraysZygoteExt = ["ComponentArrays", "Zygote"] LuxFluxTransformExt = "Flux" LuxLuxAMDGPUExt = "LuxAMDGPU" LuxLuxCUDAExt = "LuxCUDA" @@ -50,7 +48,7 @@ LuxZygoteExt = "Zygote" ADTypes = "0.1, 0.2" Adapt = "3" ChainRulesCore = "1" -ComponentArrays = "0.13, 0.14, 0.15" +ComponentArrays = "0.15.1" ConcreteStructs = "0.2" FillArrays = "0.13, 1" Flux = "0.13, 0.14" diff --git a/ext/LuxComponentArraysTrackerExt.jl b/ext/LuxComponentArraysTrackerExt.jl deleted file mode 100644 index e522e703af..0000000000 --- a/ext/LuxComponentArraysTrackerExt.jl +++ /dev/null @@ -1,35 +0,0 @@ -module LuxComponentArraysTrackerExt - -using ComponentArrays, Tracker - -function Tracker.param(ca::ComponentArray) - x = getdata(ca) - length(x) == 0 && return ComponentArray(Tracker.param(Float32[]), getaxes(ca)) - return ComponentArray(Tracker.param(x), getaxes(ca)) -end - -Tracker.extract_grad!(ca::ComponentArray) = Tracker.extract_grad!(getdata(ca)) - -function Base.materialize(bc::Base.Broadcast.Broadcasted{Tracker.TrackedStyle, Nothing, - typeof(zero), <:Tuple{<:ComponentVector}}) - ca = first(bc.args) - return ComponentArray(zero.(getdata(ca)), getaxes(ca)) -end - -function Base.getindex(g::Tracker.Grads, x::ComponentArray) - Tracker.istracked(getdata(x)) || error("Object not tracked: $x") - return g[Tracker.tracker(getdata(x))] -end - -# For TrackedArrays ignore Base.maybeview -## Tracker with views doesn't work quite well -@inline function Base.getproperty(x::ComponentVector{T, <:TrackedArray}, - s::Symbol) where {T} - return getproperty(x, Val(s)) -end - -@inline function Base.getproperty(x::ComponentVector{T, <:TrackedArray}, v::Val) where {T} - return ComponentArrays._getindex(Base.getindex, x, v) -end - -end diff --git a/ext/LuxComponentArraysZygoteExt.jl b/ext/LuxComponentArraysZygoteExt.jl deleted file mode 100644 index 452871c3ea..0000000000 --- a/ext/LuxComponentArraysZygoteExt.jl +++ /dev/null @@ -1,9 +0,0 @@ -module LuxComponentArraysZygoteExt - -using ComponentArrays, Zygote - -function Zygote.accum(x::ComponentArray, ys::ComponentArray...) - return ComponentArray(Zygote.accum(getdata(x), getdata.(ys)...), getaxes(x)) -end - -end From f84548f4bd336280a0d191da671c1551eaf69dca Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 22 Sep 2023 19:40:00 -0400 Subject: [PATCH 2/2] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 93bba56588..8462eae186 100644 --- a/Project.toml +++ b/Project.toml @@ -48,7 +48,7 @@ LuxZygoteExt = "Zygote" ADTypes = "0.1, 0.2" Adapt = "3" ChainRulesCore = "1" -ComponentArrays = "0.15.1" +ComponentArrays = "0.15.2" ConcreteStructs = "0.2" FillArrays = "0.13, 1" Flux = "0.13, 0.14"