diff --git a/ext/SymbolicsLuxExt.jl b/ext/SymbolicsLuxExt.jl index b4dfcab51..0086982c6 100644 --- a/ext/SymbolicsLuxExt.jl +++ b/ext/SymbolicsLuxExt.jl @@ -5,8 +5,10 @@ using Symbolics using Lux.LuxCore using Symbolics.SymbolicUtils -function Lux.NilSizePropagation.recursively_nillify(x::SymbolicUtils.BasicSymbolic{<:Vector{<:Real}}) - Lux.NilSizePropagation.recursively_nillify(Symbolics.wrap(x)) +@static if isdefined(Lux.NilSizePropagation, :recursively_nillify) + function Lux.NilSizePropagation.recursively_nillify(x::SymbolicUtils.BasicSymbolic{<:Vector{<:Real}}) + Lux.NilSizePropagation.recursively_nillify(Symbolics.wrap(x)) + end end @register_array_symbolic LuxCore.stateless_apply(