From 22b5c9db9d6b392d39ab68a48cf91aade21137ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20Miclu=C8=9Ba-C=C3=A2mpeanu?= Date: Sun, 8 Sep 2024 01:31:30 +0300 Subject: [PATCH] refactor: change the extension to Lux to support `recursively_nillify` --- Project.toml | 6 +++--- ext/SymbolicsLuxCoreExt.jl | 11 ----------- ext/SymbolicsLuxExt.jl | 18 ++++++++++++++++++ 3 files changed, 21 insertions(+), 14 deletions(-) delete mode 100644 ext/SymbolicsLuxCoreExt.jl create mode 100644 ext/SymbolicsLuxExt.jl diff --git a/Project.toml b/Project.toml index c5f783318..c8c00bf01 100644 --- a/Project.toml +++ b/Project.toml @@ -44,7 +44,7 @@ TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c" [weakdeps] ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Groebner = "0b43b601-686d-58a3-8a1c-6623616c7cd4" -LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" +Lux = "b2108857-7c20-44ae-9111-449ecde12c47" Nemo = "2edaba10-b0f1-5616-af89-8c11ac63239a" PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46" SymPy = "24249f21-da20-56a4-8eb1-6a02cf4ae2e6" @@ -52,7 +52,7 @@ SymPy = "24249f21-da20-56a4-8eb1-6a02cf4ae2e6" [extensions] SymbolicsForwardDiffExt = "ForwardDiff" SymbolicsGroebnerExt = "Groebner" -SymbolicsLuxCoreExt = "LuxCore" +SymbolicsLuxExt = "Lux" SymbolicsNemoExt = "Nemo" SymbolicsPreallocationToolsExt = ["PreallocationTools", "ForwardDiff"] SymbolicsSymPyExt = "SymPy" @@ -76,7 +76,7 @@ LaTeXStrings = "1.3" LambertW = "0.4.5" Latexify = "0.16" LogExpFunctions = "0.3" -LuxCore = "1" +Lux = "1" MacroTools = "0.5" NaNMath = "1" Nemo = "0.45, 0.46" diff --git a/ext/SymbolicsLuxCoreExt.jl b/ext/SymbolicsLuxCoreExt.jl deleted file mode 100644 index bd26b1395..000000000 --- a/ext/SymbolicsLuxCoreExt.jl +++ /dev/null @@ -1,11 +0,0 @@ -module SymbolicsLuxCoreExt - -using LuxCore, Symbolics - -@register_array_symbolic LuxCore.stateless_apply( - model::LuxCore.AbstractLuxLayer, x::AbstractArray, ps::Union{NamedTuple, <:AbstractVector}) begin - size = LuxCore.outputsize(model, x, LuxCore.Random.default_rng()) - eltype = Real -end - -end diff --git a/ext/SymbolicsLuxExt.jl b/ext/SymbolicsLuxExt.jl new file mode 100644 index 000000000..b4dfcab51 --- /dev/null +++ b/ext/SymbolicsLuxExt.jl @@ -0,0 +1,18 @@ +module SymbolicsLuxExt + +using Lux +using Symbolics +using Lux.LuxCore +using Symbolics.SymbolicUtils + +function Lux.NilSizePropagation.recursively_nillify(x::SymbolicUtils.BasicSymbolic{<:Vector{<:Real}}) + Lux.NilSizePropagation.recursively_nillify(Symbolics.wrap(x)) +end + +@register_array_symbolic LuxCore.stateless_apply( + model::LuxCore.AbstractLuxLayer, x::AbstractArray, ps::Union{NamedTuple, <:AbstractVector}) begin + size = LuxCore.outputsize(model, x, LuxCore.Random.default_rng()) + eltype = Real +end + +end