Gradient dimension mismatch error when training rnns #1891
Comments
I suspect this has something to do with LSTMs not accepting 1D inputs, but have not dug into the code. To be on the safe side, you should always make sure to have a batch dimension for inputs. Currently each element of |
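A minimal sketch of the batch-dimension advice above, in plain Julia with no Flux calls. The names `xs` and `xs_batched` and the sizes are hypothetical and chosen only for illustration. Each 1D input is reshaped to features × batch before it would be passed to a recurrent layer:

```julia
# Hypothetical data: a sequence of three time steps, each a 1-feature vector.
xs = [rand(Float32, 1) for _ in 1:3]

# Add an explicit batch dimension of size 1: (features,) -> (features, batch).
xs_batched = [reshape(x, :, 1) for x in xs]

# Every element is now a 1×1 matrix rather than a length-1 vector.
println(size.(xs_batched))
```

`reshape` shares memory with the original array, so this adds the batch dimension without copying the data.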
The size of each element of |
You are correct and I definitely hallucinated in some non-existent commas. If you don't need to have a learnable initial state, the easiest way to avoid this is to call

Interestingly, using explicit params results in the correct shape:

function loss2(m, xs, ys)
    Flux.reset!(m)
    sum(map((x, y) -> Flux.mse(m(x), y), xs, ys))
end

julia> gradient(m -> loss2(m, xs, ys), m)
((cell = (σ = nothing, Wi = Float32[-14.8857;;], Wh = Float32[-4.192739;;], b = Float32[-2.7642984], state0 = Float32[-7.955205;;]), state = nothing),)

It's not immediately clear to me what the issue with implicit params could be since I know people are using |
Hello, is there any update regarding this issue? I am getting a similar error on the following

Here is the error that I get:

ERROR: DimensionMismatch: arrays could not be broadcast to a common size; got a dimension with lengths 9 and 32 |
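The code that produced this error is not shown in the thread, so the following is only a generic illustration of how this class of error arises in plain Julia. The arrays `a` and `b` are hypothetical and are not the poster's model or data. Broadcasting two arrays whose corresponding dimensions differ, with neither equal to 1, raises a `DimensionMismatch`:

```julia
# Hypothetical arrays with incompatible lengths along the first dimension.
a = rand(Float32, 9)
b = rand(Float32, 32)

# Broadcasting requires matching sizes (or size 1) in every dimension,
# so this elementwise addition throws instead of returning a result.
caught = try
    a .+ b
    nothing
catch err
    err
end

println(typeof(caught))
```

In a recurrent model, a mismatch of this kind can mean that a state or gradient array has a different shape from the array it is combined with, which is consistent with the missing batch dimension discussed earlier in the thread.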
After #2500 the example becomes

using Flux

struct MyModel{G, V}
    gru::G
    h0::V
end

Flux.@layer MyModel

MyModel(d::Int) = MyModel(GRU(d => d), zeros(d))

function loss(model, xs, ys)
    ŷs = model.gru(xs, model.h0)
    return Flux.mse(ŷs, ys)
end

xs = rand(Float32, 1, 2, 3) # d x len x batch_size
ys = rand(Float32, 1, 2, 3) # d x len x batch_size
model = MyModel(1)
opt_state = Flux.setup(Adam(1e-4), model)
grad = gradient(m -> loss(m, xs, ys), model)[1]
Flux.update!(opt_state, model, grad)

and works as expected |
This happens when using RNN or GRU but doesn't when using LSTM