Unknown error raised by Zygote.jl when using eachslice and NamedTuple #1540

Closed
chooron opened this issue Dec 5, 2024 · 1 comment

chooron commented Dec 5, 2024

Hello, my program's design needs a data-storage step: I want to store a three-dimensional array as a `Vector{NamedTuple}` and then use it in computations. A simplified, reproducible example is as follows:

```julia
using Zygote

symbols = [Symbol("a_$i") for i in 1:8]  # one name per row of the final `input`
function loss(p)
    input = ones(2, 3, 10)
    data = [ones(3, 3, 10), ones(1, 3, 10), ones(2, 3, 10)]
    for d in data
        input = cat(input, d .* p, dims=1)  # append rows along dims=1
    end
    tmp = [NamedTuple{Tuple(symbols)}(eachslice(input[:, i, :], dims=1)) for i in 1:3]  # one NamedTuple per column i
    sum(tmp[1][:a_1])
end

Zygote.gradient(loss, 2.0)
```
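The forward part of `loss` can be checked on its own, without Zygote. The sketch below wraps the same steps in a hypothetical `forward` function (a function body avoids Julia's soft-scope rules for loops at the top level of a script) and assumes `p = 2.0`, as in the `gradient` call. It shows that the `cat` loop grows `input` along `dims=1` from 2 to 2 + 3 + 1 + 2 = 8 rows, which is why exactly eight names `a_1` through `a_8` are needed, and that the forward pass itself succeeds. It does not reproduce the gradient error.

```julia
# Forward pass of `loss` from the report, run without Zygote.
# `forward` is a hypothetical name; p = 2.0 matches `Zygote.gradient(loss, 2.0)`.
symbols = [Symbol("a_$i") for i in 1:8]

function forward(p)
    input = ones(2, 3, 10)
    data = [ones(3, 3, 10), ones(1, 3, 10), ones(2, 3, 10)]
    for d in data
        # each iteration appends size(d, 1) rows along the first dimension
        input = cat(input, d .* p, dims=1)
    end
    # one NamedTuple per column i; each field is a length-10 row slice
    tmp = [NamedTuple{Tuple(symbols)}(eachslice(input[:, i, :], dims=1)) for i in 1:3]
    return input, tmp
end

input, tmp = forward(2.0)
println(size(input))        # (8, 3, 10)
println(sum(tmp[1][:a_1]))  # 10.0, since row 1 comes from the initial ones(2, 3, 10)
```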

After running the program, the following error occurred:

```
ERROR: MethodError: no method matching +(::@NamedTuple{contents::Array{Float64, 3}}, ::Base.RefValue{Any})

Closest candidates are:
  +(::Any, ::Any, ::Any, ::Any...)
   @ Base operators.jl:587
  +(::ChainRulesCore.NoTangent, ::Any)
   @ ChainRulesCore D:\Julia\Julia-1.10.4\packages\packages\ChainRulesCore\6Pucz\src\tangent_arithmetic.jl:59
  +(::Any, ::MutableArithmetics.Zero)
   @ MutableArithmetics D:\Julia\Julia-1.10.4\packages\packages\MutableArithmetics\BLlgj\src\rewrite.jl:65
  ...

Stacktrace:
 [1] accum(x::@NamedTuple{contents::Array{Float64, 3}}, y::Base.RefValue{Any})
   @ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nyzjS\src\lib\lib.jl:17
 [2] accum(x::@NamedTuple{contents::Array{Float64, 3}}, y::Nothing, zs::Base.RefValue{Any})
   @ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nyzjS\src\lib\lib.jl:22
 [3] loss
   @ e:\JlCode\HydroModels\docs\tutorials\bug\zygote_1.jl:8 [inlined]
 [4] (::Zygote.Pullback{Tuple{typeof(loss), Float64}, Any})(Δ::Float64)
   @ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nyzjS\src\compiler\interface2.jl:0
 [5] (::Zygote.var"#78#79"{Zygote.Pullback{Tuple{typeof(loss), Float64}, Any}})(Δ::Float64)
   @ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nyzjS\src\compiler\interface.jl:91
 [6] gradient(f::Function, args::Float64)
   @ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nyzjS\src\compiler\interface.jl:148
 [7] top-level scope
   @ e:\JlCode\HydroModels\docs\tutorials\bug\zygote_1.jl:14
```

Which part of my code is not compatible with Zygote?

```
st Zygote
[e88e6eb3] Zygote v0.6.73
```
chooron commented Dec 6, 2024

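Solved by iterating over `eachslice(input, dims=2)` instead of indexing `input[:, i, :]` inside the comprehension:

```julia
[NamedTuple{Tuple(symbols)}(eachslice(input_, dims=1)) for input_ in eachslice(input, dims=2)]
```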
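To check that the workaround is a drop-in replacement, it can be compared with the original comprehension on the forward pass. The sketch below uses a hypothetical 8×3×10 test array with distinct entries (the same shape as `input` in `loss`) and checks that both forms yield the same field names and values. It runs without Zygote, so it says nothing by itself about the gradient; the author reports that the second form resolves the error on Zygote v0.6.73.

```julia
# Compare the original comprehension with the workaround on a plain array.
symbols = [Symbol("a_$i") for i in 1:8]

# Hypothetical test data: distinct values, same shape as `input` in `loss`.
input = reshape(collect(1.0:240.0), 8, 3, 10)

# Original: copy column i with input[:, i, :], then split it into rows.
original = [NamedTuple{Tuple(symbols)}(eachslice(input[:, i, :], dims=1)) for i in 1:3]

# Workaround from the thread: iterate over the dims=2 slices (views) directly.
workaround = [NamedTuple{Tuple(symbols)}(eachslice(input_, dims=1))
              for input_ in eachslice(input, dims=2)]

# Same values for every column i and every row name s.
same = all(original[i][s] == workaround[i][s] for i in 1:3, s in symbols)
println(same)  # true
```

One plausible explanation, not confirmed in the thread: in the original `loss`, `input` is reassigned inside the `for` loop and then captured by the comprehension's closure, so Julia stores it in a `Core.Box`, whose single field is named `contents`. That matches the `@NamedTuple{contents::Array{Float64, 3}}` in the `accum` error. In the workaround, `eachslice(input, dims=2)` is evaluated outside the closure, so `input` is no longer captured.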

chooron closed this as completed Dec 6, 2024