From 29f2c2a29a43552fb006accb477d77300af5c152 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20Miclu=C8=9Ba-C=C3=A2mpeanu?= Date: Wed, 14 Feb 2024 01:58:43 +0200 Subject: [PATCH] Add `stateless_apply` This calls `apply` and only returns the first argument. --- src/LuxCore.jl | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/LuxCore.jl b/src/LuxCore.jl index c4a0be4..7dac383 100644 --- a/src/LuxCore.jl +++ b/src/LuxCore.jl @@ -126,6 +126,21 @@ Simply calls `model(x, ps, st)` """ apply(model::AbstractExplicitLayer, x, ps, st) = model(x, ps, st) +""" + stateless_apply(model, x, ps, st) + +Calls `apply` and only returns the first argument. +""" +function stateless_apply(model::AbstractExplicitLayer, x, ps, st) + u, st = apply(model, x, ps, st) + @assert isempty(st) "Model is not stateless. Use `apply` instead." + return u +end + +function stateless_apply(::StatefulLuxLayer, x, ps, st) + return error("Model is not stateless. Use `apply` instead.") +end + """ display_name(layer::AbstractExplicitLayer)