
Unexpected result when using build_function for a Lux.jl-based expression #1187

Closed
chooron opened this issue Jul 4, 2024 · 6 comments

@chooron

chooron commented Jul 4, 2024

I want to combine multiple expressions into a single function with the intermediate values eliminated. Refer to https://discourse.julialang.org/t/how-can-i-improve-the-computational-performance-of-my-code-odeproblem-solving/114425/9. When the expressions involve only ordinary symbolic variables, I can build this function by substituting the intermediate variables. However, when the expressions are generated by a Lux.jl model, build_function does not work as expected:

```julia
using Lux
using LuxCore
using Symbolics
using StableRNGs
using ModelingToolkit
using ModelingToolkitNeuralNets
using ModelingToolkit: t_nounits as t
using ComponentArrays
using ModelingToolkitStandardLibrary.Blocks: RealInputArray
using RuntimeGeneratedFunctions
using BenchmarkTools

nn = Lux.Chain(
    Lux.Dense(3 => 16), # , Lux.tanh
    Lux.Dense(16 => 16), # , Lux.leakyrelu
    Lux.Dense(16 => 2)#, Lux.leakyrelu
)

init_params = Lux.initialparameters(StableRNG(42), nn)
init_states = Lux.initialstates(StableRNG(42), nn)
@parameters p[1:length(init_params)] = Vector(ComponentVector(init_params))
@parameters ptype::typeof(typeof(init_params)) = typeof(init_params) [tunable = false]
lazyconvert_p = Symbolics.array_term(convert, ptype, p, size=size(p))

@variables v1(t), v2(t), v3(t), v4(t)
exprs = LuxCore.stateless_apply(nn, [v1, v2, v3, v4], lazyconvert_p)[1]
```

The generated function temp_func = build_function(exprs, [v1, v2, v3, v4], expression=Val{false}) is:

```
RuntimeGeneratedFunction(#=in Symbolics=#, #=using Symbolics=#, :((ˍ₋arg1,)->begin
          #= D:\Julia\Julia-1.10.0\packages\packages\SymbolicUtils\c0xQb\src\code.jl:373 =#
          #= D:\Julia\Julia-1.10.0\packages\packages\SymbolicUtils\c0xQb\src\code.jl:374 =#
          #= D:\Julia\Julia-1.10.0\packages\packages\SymbolicUtils\c0xQb\src\code.jl:375 =#
          begin
              (getindex)((LuxCore.stateless_apply)(Chain(), Num[v1(t), v2(t), v3(t), v4(t)], (convert)(ptype, p)), 1)
          end
      end))
```

In this function, variables such as v1 and v2 are not replaced by references to the argument ˍ₋arg1. In addition, calling temp_func([3, 2, 1, 2], [ComponentVector(init_params), eltype(ComponentVector(init_params))]) raises a type conversion error:

```
ERROR: MethodError: Cannot `convert` an object of type ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:64, Axis(weight = ViewAxis(1:48, ShapedAxis((16, 3))), bias = ViewAxis(49:64, ShapedAxis((16, 1))))), layer_2 = ViewAxis(65:336, Axis(weight = ViewAxis(1:256, ShapedAxis((16, 16))), bias = ViewAxis(257:272, ShapedAxis((16, 1))))), layer_3 = ViewAxis(337:370, Axis(weight = ViewAxis(1:32, ShapedAxis((2, 16))), bias = ViewAxis(33:34, ShapedAxis((2, 1))))))}}} to an object of type Float32
```

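For comparison, the ordinary-variable case mentioned at the start of the report (eliminating intermediate variables before code generation) can be sketched as follows. This is a minimal illustration with hypothetical scalar variables `x`, `y`, and `u`, assuming Symbolics.jl's `substitute` and `build_function`:

```julia
using Symbolics

# Hypothetical scalar variables: u is an intermediate quantity defined by x and y
@variables x y u
u_def = x + 2y
out = u^2 + 1

# Replace the intermediate variable so the expression only depends on x and y
out_flat = substitute(out, Dict(u => u_def))

# expression = Val{false} returns a compiled (runtime-generated) function
f = build_function(out_flat, x, y; expression = Val{false})
f(1.0, 2.0)  # (1 + 2*2)^2 + 1 = 26.0
```

Because `u` is substituted before `build_function` runs, the generated function takes only `x` and `y` and has no separate variable for `u`.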
@ChrisRackauckas
Member

It has a registration rule for symbolic arrays, but arrays of symbolic variables are missing from the dispatch.
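To make the distinction concrete: `@variables v[1:4]` creates one symbolic array, whereas `[v1, v2, v3, v4]` is an ordinary Julia `Vector{Num}` whose elements are separate scalar variables, so a method registered for symbolic-array inputs presumably won't be selected for the second form. A minimal check, assuming the current Symbolics.jl type name `Symbolics.Arr`:

```julia
using Symbolics

@variables v[1:4] v1 v2 v3 v4

# A single symbolic array: one symbol that carries shape information
sym_arr = v
# A plain Julia vector whose elements happen to be symbolic scalars
vec_of_syms = [v1, v2, v3, v4]

println(typeof(sym_arr))      # a Symbolics.Arr type
println(typeof(vec_of_syms))  # Vector{Num}
```

In the first report, `[v1, v2, v3, v4]` is the second kind, which is consistent with the generated code above embedding `Num[v1(t), v2(t), v3(t), v4(t)]` as a literal argument instead of reading it from `ˍ₋arg1`.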

@chooron
Author

chooron commented Jul 6, 2024

I built a symbolic array and found that its elements are indeed replaced:

```julia
@variables (v(t))[1:4]
exprs = LuxCore.stateless_apply(nn, v, lazyconvert_p)[1]
temp_func = build_function(exprs, [v, ptype, p], expression=Val{false})
```

But when I call this function, there is still a type conversion problem:
```
ERROR: MethodError: Cannot `convert` an object of type ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:64, Axis(weight = ViewAxis(1:48, ShapedAxis((16, 3))), bias = ViewAxis(49:64, ShapedAxis((16, 1))))), layer_2 = ViewAxis(65:336, Axis(weight = ViewAxis(1:256, ShapedAxis((16, 16))), bias = ViewAxis(257:272, ShapedAxis((16, 1))))), layer_3 = ViewAxis(337:370, Axis(weight = ViewAxis(1:32, ShapedAxis((2, 16))), bias = ViewAxis(33:34, ShapedAxis((2, 1))))))}}} to an object of type Float32
```

@ChrisRackauckas
Member

@SebastianM-C can you point out how this differs from the ModelingToolkitNeuralNets code?

@SebastianM-C
Contributor

Yes, there are a couple of things to note:

  • The Lux.Chain definition has an input size of 3 on the first layer, but we construct an input vector with length 4. This will not cause an error, but it's likely incorrect.
  • ptype needs to be the type of ComponentVector(init_params): the generated code converts a plain Vector to that type, and a Vector cannot be converted to a NamedTuple.
  • How are you calling the function? Considering the last message, you should be using something like temp_func([[1.0, 2.0, 3.0], typeof(ComponentVector(init_params)), Vector(ComponentVector(init_params))])

Also, I'm not sure why you are using only the first element of the stateless_apply output. Is this for debugging purposes, or do you only need the first element of the neural network output?
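The ptype point can be checked in isolation. A `ComponentVector` encodes its axes (the layer/weight/bias layout) in its type, so converting a flat vector to that type can rebuild the structured vector, while there is no `convert` method from a `Vector` to a `NamedTuple`. A minimal sketch with a hypothetical two-field ComponentVector, assuming the ComponentArrays.jl `convert` method for ComponentArray types that `convert(ptype, p)` relies on in this thread:

```julia
using ComponentArrays

# Hypothetical parameter layout standing in for the Lux parameters
cv = ComponentVector(a = 1.0, b = [2.0, 3.0])
flat = Vector(cv)   # plain Vector{Float64}; the axis information is dropped

# Converting back to the ComponentVector type restores named access
restored = convert(typeof(cv), flat)
println(restored.b)  # [2.0, 3.0]

# A NamedTuple type has no conversion from a flat Vector
nt = (a = 1.0, b = [2.0, 3.0])
converted_ok = try
    convert(typeof(nt), flat)
    true
catch err
    err isa MethodError || rethrow()
    false
end
println(converted_ok)  # false
```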

@chooron
Author

chooron commented Jul 7, 2024

Thanks, here is my response:

  • This was my oversight; I didn't pay attention to the dimensions. I have since corrected it.
  • After changing eltype to typeof, it indeed worked.
  • I want to construct an ODE function manually. My approach is to build the equations symbolically and then substitute the intermediate variables one by one (see here). That is why I want to use Lux.Chain to compute one of those intermediate results.
    Here is the modified code:
```julia
using Lux
using LuxCore
using Symbolics
using StableRNGs
using ModelingToolkit
using ModelingToolkitNeuralNets
using ModelingToolkit: t_nounits as t
using ComponentArrays
using BenchmarkTools

nn = Lux.Chain(
    Lux.Dense(4 => 16), # , Lux.tanh
    Lux.Dense(16 => 16), # , Lux.leakyrelu
    Lux.Dense(16 => 2)#, Lux.leakyrelu
)

init_params = Lux.initialparameters(StableRNG(42), nn)
init_states = Lux.initialstates(StableRNG(42), nn)
@parameters p[1:length(init_params)] = Vector(ComponentVector(init_params))
@parameters ptype::typeof(typeof(init_params)) = typeof(init_params) [tunable = false]
lazyconvert_p = Symbolics.array_term(convert, ptype, p, size=size(p))

@variables v1(t), v2(t), v3(t), v4(t)
@variables (v(t))[1:4]
exprs = LuxCore.stateless_apply(nn, v, lazyconvert_p)[1]

temp_func = build_function(exprs, v, [p, ptype], expression=Val{false})
temp_func([3, 2, 1, 2], [ComponentVector(init_params), typeof(ComponentVector(init_params))])
```

This works, but I would still like to write it as LuxCore.stateless_apply(nn, [v1, v2, v3, v4], lazyconvert_p)[1], because each element of v is computed separately by a different formula, and I want to combine those results into a vector that serves as the input to the chain.

@chooron
Author

chooron commented Jul 16, 2024

Solved by using SymbolicUtils.jl; see the code below:

```julia
using RuntimeGeneratedFunctions
RuntimeGeneratedFunctions.init(@__MODULE__)
using Symbolics
using SymbolicUtils
using SymbolicUtils.Code
using Lux
using LuxCore
using StableRNGs
using ModelingToolkit
using ComponentArrays

@variables a b c d

chain = Lux.Chain(Lux.Dense(2, 16), Lux.Dense(16, 2))
init_params = ComponentVector(Lux.initialparameters(StableRNG(42), chain))
chain_params = first(@parameters p[1:length(init_params)] = Vector(ComponentVector(init_params)))
@parameters ptype::typeof(typeof(init_params)) = typeof(init_params) [tunable = false]
lazyconvert_params = Symbolics.array_term(convert, ptype, chain_params, size=size(chain_params))
# lazyconvert_params = Symbolics.array_term((x, axes) -> ComponentVector(x, axes), chain_params, getaxes(init_params), size=size(chain_params))

@variables nn_in[1:2]
@variables nn_out[1:2]

expr = LuxCore.stateless_apply(chain, nn_in, lazyconvert_params)
ass = [Assignment(nn_in, MakeArray([a, b], Vector)), Assignment(nn_out, expr), Assignment(c, nn_out[1]), Assignment(d, nn_out[2])]
let_ = Let(ass, c + d, false)
func = @RuntimeGeneratedFunction(
    toexpr(Func([a, b, ptype, p], [], let_))
)
func(1, 2, typeof(ComponentVector(init_params)), collect(init_params))
```

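The pattern in this solution (binding intermediate values with `Assignment` inside a `Let`, then emitting a function with `Func` and `toexpr`) can be seen on a smaller scale without the neural network. A minimal sketch with hypothetical scalar variables, using the same SymbolicUtils.Code constructors as the solution above; here the generated expression is evaluated with plain `eval` instead of `@RuntimeGeneratedFunction`:

```julia
using Symbolics
using SymbolicUtils
using SymbolicUtils.Code

@variables a b c

# c is an intermediate value bound inside the generated function body
assignments = [Assignment(c, a + b)]
body = Let(assignments, c^2, false)

# Build a function of (a, b) and turn it into a Julia expression
ex = toexpr(Func([a, b], [], body))
f = eval(ex)

f(1, 2)  # (1 + 2)^2 = 9
```

In the full solution, `MakeArray([a, b], Vector)` plays the same role for the network input: it assembles separately computed scalars into the vector passed to `LuxCore.stateless_apply`.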
@chooron chooron closed this as completed Jul 16, 2024