Skip to content

Commit

Permalink
fixup! refactor: use RealInputArray and RealOutputArray
Browse files Browse the repository at this point in the history
  • Loading branch information
sathvikbhagavan committed Apr 19, 2024
1 parent 1f281f6 commit 37890ff
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions src/ModelingToolkitNeuralNets.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
module ModelingToolkitNeuralNets

using ModelingToolkit: @parameters, @named, ODESystem, t_nounits, @connector, @variables,
Equation
using ModelingToolkit: @parameters, @named, ODESystem, t_nounits
using ModelingToolkitStandardLibrary.Blocks: RealInputArray, RealOutputArray
using Symbolics: Symbolics, @register_array_symbolic, @wrapped
using LuxCore: stateless_apply
Expand Down Expand Up @@ -31,17 +30,17 @@ function NeuralNetworkBlock(n_input = 1,
ca = ComponentArray{eltype}(init_params)

@parameters p[1:length(ca)] = Vector(ca)
# @parameters T::typeof(typeof(p))=typeof(p) [tunable = false]
@parameters T::typeof(typeof(ca))=typeof(ca) [tunable = false]

@named input = RealInputArray(nin = n_input)
@named output = RealOutputArray(nout = n_output)

out = stateless_apply(chain, input.u, lazyconvert(typeof(ca), p))
out = stateless_apply(chain, input.u, lazyconvert(T, p))

eqs = [output.u ~ out]

@named ude_comp = ODESystem(
eqs, t_nounits, [], [p], systems = [input, output])
eqs, t_nounits, [], [p, T], systems = [input, output])
return ude_comp
end

Expand Down

0 comments on commit 37890ff

Please sign in to comment.