From 161b64c0d0edd0ee475fc9e6ca59fdb71681e8d9 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 25 Nov 2024 13:50:22 -0500 Subject: [PATCH] fix: use generic broadcasting for complex numbers (#1106) --- lib/LuxLib/Project.toml | 2 +- lib/LuxLib/src/impl/bias_activation.jl | 5 ++++- lib/LuxLib/src/traits.jl | 7 ++++++- test/issue_tests.jl | 27 ++++++++++++++++++++++++++ 4 files changed, 38 insertions(+), 3 deletions(-) create mode 100644 test/issue_tests.jl diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 61bf3eb05b..aa927150be 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "1.3.9" +version = "1.3.10" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/lib/LuxLib/src/impl/bias_activation.jl b/lib/LuxLib/src/impl/bias_activation.jl index f96531a7d7..789b15a31e 100644 --- a/lib/LuxLib/src/impl/bias_activation.jl +++ b/lib/LuxLib/src/impl/bias_activation.jl @@ -80,7 +80,8 @@ function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(bias_activation), return y, ∇bias_activation_rrule end - y, ∇broadcast = CRC.rrule_via_ad(cfg, broadcast, σ ∘ +, x, reshape_bias(x, bias)) + y, ∇broadcast = CRC.rrule_via_ad( + cfg, broadcast_bias_activation_generic, σ, x, reshape_bias(x, bias)) ∇bias_activation_rrule = @closure Δ -> begin _, _, ∂x, ∂bias = ∇broadcast(Δ) return ∂∅, ∂∅, ∂∅, 𝒫x(∂x), 𝒫bias(vec(∂bias)) @@ -88,6 +89,8 @@ function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(bias_activation), return y, ∇bias_activation_rrule end +@inline broadcast_bias_activation_generic(σ::F, x, b) where {F} = σ.(x .+ b) + bias_activation!!(::typeof(identity), x::AbstractVector, ::Nothing) = x for bType in (Nothing, AbstractVector) @eval function bias_activation!!(σ::F, x::AbstractVector, bias::$(bType)) where {F} diff --git a/lib/LuxLib/src/traits.jl b/lib/LuxLib/src/traits.jl index e193817b52..0a768de86f 100644 --- a/lib/LuxLib/src/traits.jl +++ b/lib/LuxLib/src/traits.jl @@ -21,7 +21,7 @@ is_mutable_array(::Nothing) = True() ChainRulesCore.@non_differentiable is_mutable_array(::Any...) -for op in (:has_dual, :has_float16, :is_tracked) +for op in (:has_dual, :has_float16, :is_tracked, :has_complex) @eval $op(::Nothing) = False() @eval $op(x::Numeric) = $op(eltype(x)) end @@ -38,6 +38,9 @@ has_dual(::Type{<:ForwardDiff.Dual}) = True() has_float16(_) = False() has_float16(::Type{<:Float16}) = True() +has_complex(_) = False() +has_complex(::Type{<:Complex}) = True() + is_tracked(_) = False() has_autodiff_value(x) = is_tracked(x) | has_dual(x) @@ -51,6 +54,7 @@ function use_generic_broadcasting(xs::Tuple) xs_unwrapped = unrolled_map(unwrap_array, xs) return unrolled_any(has_autodiff_value, xs_unwrapped) | unrolled_any(has_float16, xs_unwrapped) | + unrolled_any(has_complex, xs_unwrapped) | unrolled_any(static_isa(StaticArray), xs_unwrapped) end @@ -198,6 +202,7 @@ Currently supported modes are: + ReverseDiff Arrays + Tracker Arrays + ForwardDiff.Dual Arrays + + Complex Arrays - `GPUBroadcastOp{dev}`: GPU Arrays where `dev` is obtained from `get_device_type(xs)`. This option dispatches should preferably use `KernelAbstractions` or specialized vendor diff --git a/test/issue_tests.jl b/test/issue_tests.jl new file mode 100644 index 0000000000..1de1519850 --- /dev/null +++ b/test/issue_tests.jl @@ -0,0 +1,27 @@ +@testitem "complex differentiation: issue #977" tags=[:misc] begin + using Lux, Zygote, Random + + rng = Random.default_rng() + Random.seed!(rng, 666) + + rbf(x) = exp.(-(x .^ 2)) + + U = Lux.Chain( + Lux.Dense(1, 10, rbf), + Lux.Dense(10, 3, rbf) + ) + + θ, st = Lux.setup(rng, U) + + function complex_step_differentiation(f::Function, x::Float64, ϵ::Float64) + return imag(f(x + ϵ * im)) / ϵ + end + + loss(t) = sum(complex_step_differentiation(τ -> U([τ], θ, st)[begin], t, 1e-5)) + + if pkgversion(LuxLib) ≥ v"1.3.10" + @test only(Zygote.gradient(loss, 1.0)) isa Float64 + else + @test_broken only(Zygote.gradient(loss, 1.0)) isa Float64 + end +end