Here is the minimal reproducer I came up with:

```julia
import Flux
import Zygote
using Functors

struct Test
    a
    b
end
@functor Test

function (m::Test)(x)
    a = x
    for f = m.a
        a = f(a)
    end
    b = x
    for f = m.b
        b = f(b)
    end
    a + b
end

t = Test([Flux.Dense(10 => 5)], [Flux.Dense(10 => 5)])
x = rand(10)
Zygote.gradient(() -> sum(t(x)), Flux.params(t))
```
The error, raised on the `for f=m.b` line, is:

```
ERROR: LoadError: DimensionMismatch("arrays could not be broadcast to a common size; got a dimension with lengths 5 and 10")
Stacktrace:
  [1] _bcs1
    @ ./broadcast.jl:516 [inlined]
  [2] _bcs
    @ ./broadcast.jl:510 [inlined]
  [3] broadcast_shape
    @ ./broadcast.jl:504 [inlined]
  [4] combine_axes
    @ ./broadcast.jl:499 [inlined]
  [5] instantiate
    @ ./broadcast.jl:281 [inlined]
  [6] materialize
    @ ./broadcast.jl:860 [inlined]
  [7] accum(x::FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, ys::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/IoW2g/src/lib/lib.jl:25
  [8] Pullback
    @ repro.jl:18 [inlined]
  [9] (::typeof(∂(λ)))(Δ::FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
 [10] Pullback
    @ repro.jl:26 [inlined]
 [11] (::typeof(∂(#3)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
 [12] (::Zygote.var"#97#98"{Zygote.Params{Zygote.Buffer{Any, Vector{Any}}}, typeof(∂(#3)), Zygote.Context})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface.jl:357
 [13] gradient(f::Function, args::Zygote.Params{Zygote.Buffer{Any, Vector{Any}}})
    @ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface.jl:76
 [14] top-level scope
    @ repro.jl:26
in expression starting at repro.jl:26
```
This happens even though all the dimensions match up correctly: running `t(x)` on its own works fine.
The error goes away with any one of the following changes:

- Replace `a + b` with just `a` or just `b`.
- Replace `a = x` with `a = copy(x)` (this is the workaround I'm using in my actual code right now).
- Replace both `for` loops with explicit indexing.

This is with Zygote version 0.6.41.
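For comparison, here is a sketch of the same reproducer with the `copy` workaround applied. The only change to the model is `a = x` becoming `a = copy(x)`; according to the report, the gradient call then succeeds. It assumes the Flux/Zygote versions from the report (implicit-parameter `Flux.params` API, `Dense` fields named `weight`), and the variable name `gs` is introduced here purely for illustration.

```julia
import Flux
import Zygote
using Functors

struct Test
    a
    b
end
@functor Test

function (m::Test)(x)
    # Workaround from the report: start the first chain from a copy of `x`
    # rather than from `x` itself. Nothing else is changed.
    a = copy(x)
    for f = m.a
        a = f(a)
    end
    b = x
    for f = m.b
        b = f(b)
    end
    a + b
end

t = Test([Flux.Dense(10 => 5)], [Flux.Dense(10 => 5)])
x = rand(10)

# `gs` (hypothetical name) holds the implicit-parameter gradients (Zygote.Grads).
gs = Zygote.gradient(() -> sum(t(x)), Flux.params(t))

# Each Dense(10 => 5) weight is a 5×10 matrix, so its gradient has that shape.
println(size(gs[t.a[1].weight]))
```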