diff --git a/Project.toml b/Project.toml index 1d34a7615..157f7bc9e 100644 --- a/Project.toml +++ b/Project.toml @@ -67,7 +67,7 @@ LaTeXStrings = "1.3" LambertW = "0.4.5" Latexify = "0.16" LogExpFunctions = "0.3" -LuxCore = "0.1.7" +LuxCore = "0.1.8" MacroTools = "0.5" NaNMath = "1" PrecompileTools = "1" diff --git a/ext/SymbolicsLuxCoreExt.jl b/ext/SymbolicsLuxCoreExt.jl index ff3ead20c..d3ab063ae 100644 --- a/ext/SymbolicsLuxCoreExt.jl +++ b/ext/SymbolicsLuxCoreExt.jl @@ -3,14 +3,14 @@ module SymbolicsLuxCoreExt using LuxCore, Symbolics @register_array_symbolic LuxCore.partial_apply( - model::LuxCore.AbstractExplicitContainerLayer, x::AbstractArray, ps::NamedTuple, st::NamedTuple) begin - size = ((model[end].out_dims),) + model::LuxCore.AbstractExplicitContainerLayer, x::AbstractArray, ps::Union{NamedTuple, <:AbstractVector}, st::NamedTuple) begin + size = LuxCore.outputsize(model[end]) eltype = Real end @register_array_symbolic LuxCore.partial_apply( - model::LuxCore.AbstractExplicitLayer, x::AbstractArray, ps::NamedTuple, st::NamedTuple) begin - size = ((model.out_dims),) + model::LuxCore.AbstractExplicitLayer, x::AbstractArray, ps::Union{NamedTuple, <:AbstractVector}, st::NamedTuple) begin + size = LuxCore.outputsize(model) eltype = Real end