NNODE training fails with autodiff=true
#725
@ChrisRackauckas is this a known issue?
It wasn't, but now it is.
Updated MWE:

```julia
using Flux
using Random, NeuralPDE
using OrdinaryDiffEq, Statistics
import OptimizationOptimisers
Random.seed!(100)

# Run a solve on scalars
linear = (u, p, t) -> cos(2pi * t)
tspan = (0.0f0, 1.0f0)
u0 = 0.0f0
prob = ODEProblem(linear, u0, tspan)
chain = Flux.Chain(Dense(1, 5, σ), Dense(5, 1))
opt = OptimizationOptimisers.Adam(0.1, (0.9, 0.95))
sol = solve(prob, NeuralPDE.NNODE(chain, opt; autodiff=true), dt = 1 / 20.0f0, verbose = true,
    abstol = 1.0f-10, maxiters = 200)
```

This no longer errors out with the same error. (#783 has a check that errors out if `autodiff` is true.) Removing that check gives me this:

```julia
julia> sol = solve(prob, NeuralPDE.NNODE(chain, opt; autodiff=true), dt = 1 / 20.0f0, verbose = true,
       abstol = 1.0f-10, maxiters = 200)
┌ Warning: `ForwardDiff.jacobian(f, x)` within Zygote cannot track gradients with respect to `f`,
│ and `f` appears to be a closure, or a struct with fields (according to `issingletontype(typeof(f))`).
│ typeof(f) = NeuralPDE.var"#163#164"{NeuralPDE.ODEPhi{Optimisers.Restructure{Chain{Tuple{Dense{typeof(σ), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, @NamedTuple{layers::Tuple{@NamedTuple{weight::Int64, bias::Int64, σ::Tuple{}}, @NamedTuple{weight::Int64, bias::Int64, σ::Tuple{}}}}}, Float32, Float32, Nothing}, Vector{Float32}}
└ @ Zygote ~/.julia/packages/Zygote/WOy6z/src/lib/forward.jl:150
Current loss is: 121.3777458193083, Iteration: 1
Current loss is: 121.3777458193083, Iteration: 2
Current loss is: 121.3777458193083, Iteration: 3
⋮
Current loss is: 121.3777458193083, Iteration: 200
Current loss is: 121.3777458193083, Iteration: 201
retcode: Success
Interpolation: Trained neural network interpolation
t: 0.0f0:0.05f0:1.0f0
u: 21-element Vector{Float32}:
0.0
0.006315714
0.011539376
0.015674114
0.018725103
0.020699587
0.021606745
0.021457678
⋮
-0.007863174
-0.015861165
-0.024744594
-0.034488622
-0.04506795
-0.05645652
-0.06862798
```

The loss stays constant for all 201 iterations, so the NNODE is not being trained.
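The Zygote warning in this output checks whether `f` is a singleton type. A closure that captures variables is compiled into a struct with one field per captured variable, so it is not a singleton, and per the warning Zygote cannot track gradients with respect to whatever `f` captured. The reported `typeof(f)` has an `ODEPhi{…}` and a `Vector{Float32}` as type parameters, which suggests the closure captures the network wrapper and, presumably, the parameter vector θ. If the loss depends on θ through that derivative, losing those gradients would be consistent with the constant loss. A minimal plain-Julia sketch of the distinction the warning checks (`make_phi` is a hypothetical helper, not NeuralPDE code):

```julia
# Hypothetical helper: returns an anonymous function that captures θ,
# loosely mirroring a closure over the NNODE parameters
make_phi(θ) = t -> θ[1] * t + θ[2]

captured = make_phi([2.0, 1.0])  # closure: its type has a field holding θ
plain = t -> 2.0 * t + 1.0       # captures nothing

println(fieldnames(typeof(captured)))            # (:θ,)
println(Base.issingletontype(typeof(captured)))  # false -> the warning would fire
println(Base.issingletontype(typeof(plain)))     # true
```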
`dt` is given with `autodiff=true`
The Flux type conversion drops duals, so that's something to start by removing. Let's first convert everything to Lux, clean up and delete code, then isolate the issue.
Now that the Flux removal is done (#789), I revisited this to see what was happening. With [email protected]:

```julia
julia> sol = solve(prob, NeuralPDE.NNODE(luxchain, opt, autodiff = true), dt = 1 / 20.0f0, verbose = true,
       abstol = 1.0f-10, maxiters = 200)
┌ Warning: `ForwardDiff.jacobian(f, x)` within Zygote cannot track gradients with respect to `f`,
│ and `f` appears to be a closure, or a struct with fields (according to `issingletontype(typeof(f))`).
│ typeof(f) = NeuralPDE.var"#163#164"{NeuralPDE.ODEPhi{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{true, typeof(sigmoid_fast), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, layer_2::Lux.Dense{true, typeof(identity), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}}, Nothing}, Float32, Float32, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}}, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:10, Axis(weight = ViewAxis(1:5, ShapedAxis((5, 1), NamedTuple())), bias = ViewAxis(6:10, ShapedAxis((5, 1), NamedTuple())))), layer_2 = ViewAxis(11:16, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5), NamedTuple())), bias = ViewAxis(6:6, ShapedAxis((1, 1), NamedTuple())))))}}}}
└ @ Zygote ~/.julia/packages/Zygote/jxHJc/src/lib/forward.jl:150
Current loss is: 133.45653590504156, Iteration: 1
Current loss is: 133.45653590504156, Iteration: 2
Current loss is: 133.45653590504156, Iteration: 3
⋮
Current loss is: 133.45653590504156, Iteration: 200
Current loss is: 133.45653590504156, Iteration: 201
retcode: Success
Interpolation: Trained neural network interpolation
t: 0.0f0:0.05f0:1.0f0
u: 21-element Vector{Float32}:
0.0
-0.026961738
-0.054784633
-0.08346676
-0.1130049
-0.14339462
-0.17463014
-0.20670454
-0.23960975
-0.27333638
-0.30787417
-0.3432115
-0.37933594
-0.41623402
-0.4538912
-0.49229237
-0.53142136
-0.5712613
-0.61179453
-0.6530033
-0.6948684
```

Here the loss still remains constant. But with [email protected], I get an error:

```julia
julia> sol = solve(prob, NeuralPDE.NNODE(luxchain, opt, autodiff = true), dt = 1 / 20.0f0, verbose = true,
       abstol = 1.0f-10, maxiters = 200)
┌ Warning: `ForwardDiff.jacobian(f, x)` within Zygote cannot track gradients with respect to `f`,
│ and `f` appears to be a closure, or a struct with fields (according to `issingletontype(typeof(f))`).
│ typeof(f) = NeuralPDE.var"#163#164"{NeuralPDE.ODEPhi{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{true, typeof(sigmoid_fast), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, layer_2::Lux.Dense{true, typeof(identity), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}}, Nothing}, Float32, Float32, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}}, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:10, Axis(weight = ViewAxis(1:5, ShapedAxis((5, 1), NamedTuple())), bias = ViewAxis(6:10, ShapedAxis((5, 1), NamedTuple())))), layer_2 = ViewAxis(11:16, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5), NamedTuple())), bias = ViewAxis(6:6, ShapedAxis((1, 1), NamedTuple())))))}}}}
└ @ Zygote ~/.julia/packages/Zygote/jxHJc/src/lib/forward.jl:150
ERROR: MethodError: no method matching zero(::Type{ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{…}}}})
Closest candidates are:
zero(::Type{Union{}}, Any...)
@ Base number.jl:310
zero(::Type{Dates.Time})
@ Dates ~/.julia/juliaup/julia-1.10.0+0.x64.linux.gnu/share/julia/stdlib/v1.10/Dates/src/types.jl:440
zero(::Type{Pkg.Resolve.FieldValue})
@ Pkg ~/.julia/juliaup/julia-1.10.0+0.x64.linux.gnu/share/julia/stdlib/v1.10/Pkg/src/Resolve/fieldvalues.jl:38
...
Stacktrace:
[1] (::OptimizationZygoteExt.var"#38#56"{OptimizationZygoteExt.var"#37#55"{…}})(::ComponentArrays.ComponentVector{Float32, Vector{…}, Tuple{…}}, ::ComponentArrays.ComponentVector{Float32, Vector{…}, Tuple{…}})
@ OptimizationZygoteExt ~/.julia/packages/Optimization/79XSq/ext/OptimizationZygoteExt.jl:93
[2] macro expansion
@ ~/.julia/packages/OptimizationOptimisers/AOkbT/src/OptimizationOptimisers.jl:68 [inlined]
[3] macro expansion
@ ~/.julia/packages/Optimization/79XSq/src/utils.jl:41 [inlined]
[4] __solve(cache::Optimization.OptimizationCache{OptimizationFunction{…}, Optimization.ReInitCache{…}, Nothing, Nothing, Nothing, Nothing, Nothing, Optimisers.Adam, Base.Iterators.Cycle{…}, Bool, NeuralPDE.var"#192#196"{…}})
@ OptimizationOptimisers ~/.julia/packages/OptimizationOptimisers/AOkbT/src/OptimizationOptimisers.jl:66
[5] solve!(cache::Optimization.OptimizationCache{OptimizationFunction{…}, Optimization.ReInitCache{…}, Nothing, Nothing, Nothing, Nothing, Nothing, Optimisers.Adam, Base.Iterators.Cycle{…}, Bool, NeuralPDE.var"#192#196"{…}})
@ SciMLBase ~/.julia/packages/SciMLBase/slQep/src/solve.jl:179
[6] solve(::OptimizationProblem{true, OptimizationFunction{…}, ComponentArrays.ComponentVector{…}, SciMLBase.NullParameters, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, @Kwargs{}}, ::Optimisers.Adam; kwargs::@Kwargs{callback::NeuralPDE.var"#192#196"{…}, maxiters::Int64})
@ SciMLBase ~/.julia/packages/SciMLBase/slQep/src/solve.jl:96
[7] __solve(::ODEProblem{…}, ::NNODE{…}; dt::Float32, timeseries_errors::Bool, save_everystep::Bool, adaptive::Bool, abstol::Float32, reltol::Float32, verbose::Bool, saveat::Nothing, maxiters::Int64, tstops::Nothing)
@ NeuralPDE ~/NeuralPDE.jl/src/ode_solve.jl:489
[8] __solve
@ ~/NeuralPDE.jl/src/ode_solve.jl:373 [inlined]
[9] solve_call(_prob::ODEProblem{…}, args::NNODE{…}; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::@Kwargs{…})
@ DiffEqBase ~/.julia/packages/DiffEqBase/eLhx9/src/solve.jl:609
[10] solve_call
@ DiffEqBase ~/.julia/packages/DiffEqBase/eLhx9/src/solve.jl:567 [inlined]
[11] #solve_up#42
@ DiffEqBase ~/.julia/packages/DiffEqBase/eLhx9/src/solve.jl:1058 [inlined]
[12] solve_up
@ DiffEqBase ~/.julia/packages/DiffEqBase/eLhx9/src/solve.jl:1044 [inlined]
[13] #solve#40
@ DiffEqBase ~/.julia/packages/DiffEqBase/eLhx9/src/solve.jl:981 [inlined]
[14] top-level scope
@ REPL[14]:1
Some type information was truncated. Use `show(err)` to see complete types.
```

This is because of SciML/Optimization.jl#679. This traces from the
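The `MethodError` above comes from calling `zero` on the *type* `ComponentVector{…}` rather than on a value. Base Julia defines `zero(x)` for an array instance, but there is no `zero` method for an array type, since the type alone does not determine the length. A minimal sketch of the same failure mode, using a plain `Vector{Float32}` as a stand-in for the `ComponentVector` (we have not traced exactly how the Optimization.jl call site reaches this call):

```julia
x = Float32[1.0, 2.0, 3.0]

# Works: zero of an array *value* is a same-shaped array of zeros
z = zero(x)

# Fails: there is no zero method for an array *type*
err = try
    zero(typeof(x))
    nothing
catch e
    e
end

println(z)                    # Float32[0.0, 0.0, 0.0]
println(err isa MethodError)  # true
```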
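Instead of using `ForwardDiff.jacobian`, we could do the dual evaluation directly. It's the same as what OrdinaryDiffEq.jl does here: https://github.com/SciML/OrdinaryDiffEq.jl/blob/master/src/derivative_wrappers.jl#L84-L103

```julia
using ForwardDiff
using ForwardDiff: Dual

# Tag type marking these duals as NeuralPDE's
struct NeuralPDETag end

# `phi`, `θ` and `t` are the NNODE trial solution, its parameters and the time point.
# Seed t with a unit partial so a single evaluation of phi also carries dphi/dt.
T = typeof(ForwardDiff.Tag(NeuralPDETag(), eltype(t)))
tdual = Dual{T, eltype(t), 1}(t, ForwardDiff.Partials((one(typeof(t)),)))
# Read the first (only) partial off each output entry: dphi/dt
first.(ForwardDiff.partials.(phi(tdual, θ)))
```

and add a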
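For readers less familiar with dual numbers, here is a self-contained sketch of the same idea in plain Julia, without ForwardDiff: seed `t` with a derivative of one, evaluate the function once, and read the derivative off the result. `MiniDual`, `toy_phi` and `dphi_dt` are hypothetical names for illustration; `toy_phi` is a one-neuron toy, not NeuralPDE's actual trial solution `phi`.

```julia
# Minimal forward-mode dual number: a value plus one partial derivative.
# (Hypothetical stand-in for ForwardDiff.Dual, for illustration only.)
struct MiniDual
    val::Float64
    der::Float64
end

# Only the rules toy_phi needs: shift, scale, and the chain rule for tanh
Base.:+(a::MiniDual, b::Real) = MiniDual(a.val + b, a.der)
Base.:*(a::Real, b::MiniDual) = MiniDual(a * b.val, a * b.der)
Base.tanh(a::MiniDual) = MiniDual(tanh(a.val), (1 - tanh(a.val)^2) * a.der)

# Hypothetical toy "network": one tanh neuron with θ = [w1, b1, w2, b2]
toy_phi(t, θ) = θ[3] * tanh(θ[1] * t + θ[2]) + θ[4]

# Seed t with derivative 1, evaluate once, and read dphi/dt off the result
dphi_dt(phi, t::Real, θ) = phi(MiniDual(t, 1.0), θ).der

θ = [0.5, -0.2, 1.5, 0.3]
t = 0.7
ad = dphi_dt(toy_phi, t, θ)
exact = θ[3] * (1 - tanh(θ[1] * t + θ[2])^2) * θ[1]
println(isapprox(ad, exact))  # true
```

The ForwardDiff version does the same thing with a proper `Dual` type and a tag; the tag is what keeps these perturbations from being confused with any other duals that may be in flight, for example when an outer AD pass is also running.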
Ok, will try this.
MWE:
Running one of the tests in `NNODE_tests.jl`, this works:
This errors out:
Stacktrace: