From cf7438dd8775286199bcfe9257dcc13c074798c8 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 15 Nov 2024 12:42:01 +0530 Subject: [PATCH 1/3] fix: fix searching inside callable symbolics in `get_variables!` --- src/utils.jl | 3 ++- test/utils.jl | 8 ++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/src/utils.jl b/src/utils.jl index 5fffae8a5..17d91a971 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -62,7 +62,7 @@ function is_singleton(e) op = operation(e) op === getindex && return true iscall(op) && return is_singleton(op) # recurse to reach getindex for array element variables - return issym(op) + return issym(op) && !hasmetadata(e, CallWithParent) else return issym(e) end @@ -76,6 +76,7 @@ function get_variables!(vars, e::Symbolic, varlist=nothing) push!(vars, e) end else + get_variables!(vars, operation(e), varlist) foreach(x -> get_variables!(vars, x, varlist), arguments(e)) end return (vars isa AbstractVector) ? unique!(vars) : vars diff --git a/test/utils.jl b/test/utils.jl index 727c53366..709024da5 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -19,6 +19,14 @@ using Symbolics: symbolic_to_float, var_from_nested_derivative, unwrap sorted_vars2 = Symbolics.get_variables(ex2; sort = true) @test isequal(sorted_vars2, [x, y]) + + @variables c(..) + ex3 = c(x) + c(t) - c(c(t) + y) + vars3 = Symbolics.get_variables(ex3) + @test length(vars3) == 4 + + sorted_vars3 = Symbolics.get_variables(ex3; sort = true) + @test isequal(sorted_vars3, [c.f, t, x, y]) end @testset "symbolic_to_float" begin From e393216f35f5afb4a6101be6e40ff53106bdc7fc Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 15 Nov 2024 14:54:25 +0530 Subject: [PATCH 2/3] fix: handle removal of `Lux.NilSizePropagation.recursively_nillify` --- ext/SymbolicsLuxExt.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ext/SymbolicsLuxExt.jl b/ext/SymbolicsLuxExt.jl index b4dfcab51..0086982c6 100644 --- a/ext/SymbolicsLuxExt.jl +++ b/ext/SymbolicsLuxExt.jl @@ -5,8 +5,10 @@ using Symbolics using Lux.LuxCore using Symbolics.SymbolicUtils -function Lux.NilSizePropagation.recursively_nillify(x::SymbolicUtils.BasicSymbolic{<:Vector{<:Real}}) - Lux.NilSizePropagation.recursively_nillify(Symbolics.wrap(x)) +@static if isdefined(Lux.NilSizePropagation, :recursively_nillify) + function Lux.NilSizePropagation.recursively_nillify(x::SymbolicUtils.BasicSymbolic{<:Vector{<:Real}}) + Lux.NilSizePropagation.recursively_nillify(Symbolics.wrap(x)) + end end @register_array_symbolic LuxCore.stateless_apply( From 05148595cfe0a81701410ec5f212e7743c643ed8 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 15 Nov 2024 15:43:15 +0530 Subject: [PATCH 3/3] fix: wrap symbolic before calling `outputsize` in `LuxExt` --- ext/SymbolicsLuxExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/SymbolicsLuxExt.jl b/ext/SymbolicsLuxExt.jl index 0086982c6..60e748950 100644 --- a/ext/SymbolicsLuxExt.jl +++ b/ext/SymbolicsLuxExt.jl @@ -13,7 +13,7 @@ end @register_array_symbolic LuxCore.stateless_apply( model::LuxCore.AbstractLuxLayer, x::AbstractArray, ps::Union{NamedTuple, <:AbstractVector}) begin - size = LuxCore.outputsize(model, x, LuxCore.Random.default_rng()) + size = LuxCore.outputsize(model, Symbolics.wrap(x), LuxCore.Random.default_rng()) eltype = Real end