From 356e4c903f20c6fbe5ff0e2ae917e74edc103ded Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20Miclu=C8=9Ba-C=C3=A2mpeanu?= Date: Fri, 16 Feb 2024 04:16:00 +0200 Subject: [PATCH] use `LuxCore.outputsize` --- Project.toml | 2 +- ext/SymbolicsLuxCoreExt.jl | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index 1d34a7615..157f7bc9e 100644 --- a/Project.toml +++ b/Project.toml @@ -67,7 +67,7 @@ LaTeXStrings = "1.3" LambertW = "0.4.5" Latexify = "0.16" LogExpFunctions = "0.3" -LuxCore = "0.1.7" +LuxCore = "0.1.8" MacroTools = "0.5" NaNMath = "1" PrecompileTools = "1" diff --git a/ext/SymbolicsLuxCoreExt.jl b/ext/SymbolicsLuxCoreExt.jl index ff3ead20c..d3ab063ae 100644 --- a/ext/SymbolicsLuxCoreExt.jl +++ b/ext/SymbolicsLuxCoreExt.jl @@ -3,14 +3,14 @@ module SymbolicsLuxCoreExt using LuxCore, Symbolics @register_array_symbolic LuxCore.partial_apply( - model::LuxCore.AbstractExplicitContainerLayer, x::AbstractArray, ps::NamedTuple, st::NamedTuple) begin - size = ((model[end].out_dims),) + model::LuxCore.AbstractExplicitContainerLayer, x::AbstractArray, ps::Union{NamedTuple, <:AbstractVector}, st::NamedTuple) begin + size = LuxCore.outputsize(model[end]) eltype = Real end @register_array_symbolic LuxCore.partial_apply( - model::LuxCore.AbstractExplicitLayer, x::AbstractArray, ps::NamedTuple, st::NamedTuple) begin - size = ((model.out_dims),) + model::LuxCore.AbstractExplicitLayer, x::AbstractArray, ps::Union{NamedTuple, <:AbstractVector}, st::NamedTuple) begin + size = LuxCore.outputsize(model) eltype = Real end