diff --git a/src/dae_solve.jl b/src/dae_solve.jl index f2c861da02..755a3abbe0 100644 --- a/src/dae_solve.jl +++ b/src/dae_solve.jl @@ -43,6 +43,7 @@ end function NNDAE(chain, opt, init_params = nothing; strategy = nothing, autodiff = false, kwargs...) + !(chain isa Lux.AbstractExplicitLayer) && (chain = Lux.transform(chain)) NNDAE(chain, opt, init_params, autodiff, strategy, kwargs) end