From c42afdee1005bca6f6b0babaf17211e16e39cc19 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 11 Aug 2024 09:53:35 -0700 Subject: [PATCH] fix: remove type pirated functions from Lux --- Project.toml | 4 ++-- ext/LuxReverseDiffExt/LuxReverseDiffExt.jl | 1 - ext/LuxReverseDiffExt/apply.jl | 16 ---------------- ext/LuxTrackerExt/LuxTrackerExt.jl | 1 - ext/LuxTrackerExt/apply.jl | 14 -------------- src/chainrules.jl | 8 -------- 6 files changed, 2 insertions(+), 42 deletions(-) delete mode 100644 ext/LuxReverseDiffExt/apply.jl delete mode 100644 ext/LuxTrackerExt/apply.jl diff --git a/Project.toml b/Project.toml index 18c4bf153..7fa453e04 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Lux" uuid = "b2108857-7c20-44ae-9111-449ecde12c47" authors = ["Avik Pal and contributors"] -version = "0.5.63" +version = "0.5.64-DEV" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -86,7 +86,7 @@ Functors = "0.4.12" GPUArraysCore = "0.1.6" LinearAlgebra = "1.10" LossFunctions = "0.11.1" -LuxCore = "0.1.16" +LuxCore = "0.1.24" LuxDeviceUtils = "0.1.26" LuxLib = "0.3.40" MLUtils = "0.4.3" diff --git a/ext/LuxReverseDiffExt/LuxReverseDiffExt.jl b/ext/LuxReverseDiffExt/LuxReverseDiffExt.jl index 35eecb733..5627e789f 100644 --- a/ext/LuxReverseDiffExt/LuxReverseDiffExt.jl +++ b/ext/LuxReverseDiffExt/LuxReverseDiffExt.jl @@ -9,7 +9,6 @@ using LuxCore: LuxCore, AbstractExplicitLayer using ReverseDiff: ReverseDiff, ForwardExecutor, ReverseExecutor, TrackedArray, TrackedReal, @grad_from_chainrules -include("apply.jl") include("utils.jl") include("rules.jl") include("training.jl") diff --git a/ext/LuxReverseDiffExt/apply.jl b/ext/LuxReverseDiffExt/apply.jl deleted file mode 100644 index 94865957f..000000000 --- a/ext/LuxReverseDiffExt/apply.jl +++ /dev/null @@ -1,16 +0,0 @@ -# TODO: move to LuxCore -# AoS to SoA conversion -function LuxCore.apply( - m::AbstractExplicitLayer, x::AbstractArray{<:ReverseDiff.TrackedReal}, ps, st) - @warn "Lux.apply(m::AbstractExplicitLayer, \ - x::AbstractArray{<:ReverseDiff.TrackedReal}, ps, st) input was corrected to \ - Lux.apply(m::AbstractExplicitLayer, x::ReverseDiff.TrackedArray}, ps, \ - st).\n\n\ - 1. If this was not the desired behavior overload the dispatch on `m`.\n\n\ - 2. This might have performance implications. Check which layer was causing this \ - problem using `Lux.Experimental.@debug_mode`." maxlog=1 - return LuxCore.apply(m, reshape(ArrayInterface.aos_to_soa(x), size(x)), ps, st) -end - -## Prevent an infinite loop -LuxCore.apply(m::AbstractExplicitLayer, x::TrackedArray, ps, st) = m(x, ps, st) diff --git a/ext/LuxTrackerExt/LuxTrackerExt.jl b/ext/LuxTrackerExt/LuxTrackerExt.jl index a5cfa8110..14c74fbab 100644 --- a/ext/LuxTrackerExt/LuxTrackerExt.jl +++ b/ext/LuxTrackerExt/LuxTrackerExt.jl @@ -10,7 +10,6 @@ using Tracker: Tracker, TrackedArray, TrackedReal, @grad_from_chainrules const CRC = ChainRulesCore -include("apply.jl") include("utils.jl") include("rules.jl") include("training.jl") diff --git a/ext/LuxTrackerExt/apply.jl b/ext/LuxTrackerExt/apply.jl deleted file mode 100644 index 0ebce4707..000000000 --- a/ext/LuxTrackerExt/apply.jl +++ /dev/null @@ -1,14 +0,0 @@ -# TODO: move to LuxCore -# AoS to SoA conversion -function LuxCore.apply(m::AbstractExplicitLayer, x::AbstractArray{<:TrackedReal}, ps, st) - @warn "LuxCore.apply(m::AbstractExplicitLayer, \ - x::AbstractArray{<:Tracker.TrackedReal}, ps, st) input was corrected to \ - LuxCore.apply(m::AbstractExplicitLayer, x::Tracker.TrackedArray}, ps, st).\n\n\ - 1. If this was not the desired behavior overload the dispatch on `m`.\n\n\ - 2. This might have performance implications. Check which layer was causing this \ - problem using `Lux.Experimental.@debug_mode`." maxlog=1 - return LuxCore.apply(m, ArrayInterface.aos_to_soa(x), ps, st) -end - -## Prevent an infinite loop -LuxCore.apply(m::AbstractExplicitLayer, x::TrackedArray, ps, st) = m(x, ps, st) diff --git a/src/chainrules.jl b/src/chainrules.jl index 2e72c018a..8cd444d06 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -71,14 +71,6 @@ function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(foldl_init), return y, ∇foldl_init end -# getproperty rrule for AbstractExplicitLayer. needed for type stability of Zygote -# gradients -function CRC.rrule(::typeof(getproperty), m::AbstractExplicitLayer, name::Symbol) - res = getproperty(m, name) - ∇getproperty = Δ -> ntuple(Returns(NoTangent()), 3) - return res, ∇getproperty -end - # Loss Functions @inline function CRC.rrule( ::typeof(__fused_agg), ::typeof(sum), lfn::LossFunctions.Traits.Loss, x, y)