Skip to content

Commit

Permalink
use LuxCore.outputsize
Browse files Browse the repository at this point in the history
  • Loading branch information
SebastianM-C committed Feb 16, 2024
1 parent 9dbd42d commit 356e4c9
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ LaTeXStrings = "1.3"
LambertW = "0.4.5"
Latexify = "0.16"
LogExpFunctions = "0.3"
LuxCore = "0.1.7"
LuxCore = "0.1.8"
MacroTools = "0.5"
NaNMath = "1"
PrecompileTools = "1"
Expand Down
8 changes: 4 additions & 4 deletions ext/SymbolicsLuxCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@ module SymbolicsLuxCoreExt
using LuxCore, Symbolics

@register_array_symbolic LuxCore.partial_apply(
model::LuxCore.AbstractExplicitContainerLayer, x::AbstractArray, ps::NamedTuple, st::NamedTuple) begin
size = ((model[end].out_dims),)
model::LuxCore.AbstractExplicitContainerLayer, x::AbstractArray, ps::Union{NamedTuple, <:AbstractVector}, st::NamedTuple) begin
size = LuxCore.outputsize(model[end])
eltype = Real
end

@register_array_symbolic LuxCore.partial_apply(
model::LuxCore.AbstractExplicitLayer, x::AbstractArray, ps::NamedTuple, st::NamedTuple) begin
size = ((model.out_dims),)
model::LuxCore.AbstractExplicitLayer, x::AbstractArray, ps::Union{NamedTuple, <:AbstractVector}, st::NamedTuple) begin
size = LuxCore.outputsize(model)
eltype = Real
end

Expand Down

0 comments on commit 356e4c9

Please sign in to comment.