After much head-scratching about why my derivatives weren't coming out right, I discovered that gradients come out wrong when I build new arrays by indexing into existing arrays. Sometimes, the derivatives are omitted completely:
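using Knet  # provides param; @diff and grad come from AutoGrad, which Knet re-exports

gb = param(1,1; atype=Array{Float32})
function loss_test(b)
    loss = [b[1]]      # new array built from an indexed element
    return sum(loss)
end
∇b = @diff loss_test(gb)
grad(∇b, gb)  # returns Sparse(Matrix{Float32}(1,1)()); should be [1.0;;]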
Sometimes, they overwrite the derivatives of arrays they are added to:
gb = param(2,2; atype=Array{Float32})
function loss_test(bb)
    loss = sum(bb, dims=2)
    b = bb.*1
    loss += [b[i] for i=1:2]   # array built by indexing, added to loss
    return sum(loss)
end
∇b = @diff loss_test(gb)
grad(∇b, gb)  # returns [1.0 0.0; 1.0 0.0]; should be [2.0 1.0; 2.0 1.0]
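For reference, here is a short sketch of where the expected value comes from. It writes $B$ for `bb` and assumes that `b[i]` with a single index addresses the matrix in Julia's column-major linear order, so `b[1]` and `b[2]` are $B_{11}$ and $B_{21}$:

```math
\begin{aligned}
L(B) &= \sum_{i=1}^{2}\sum_{j=1}^{2} B_{ij} + B_{11} + B_{21}, \\
\frac{\partial L}{\partial B_{ij}} &= 1 + [\,j = 1\,], \\
\nabla_B L &= \begin{pmatrix} 2 & 1 \\ 2 & 1 \end{pmatrix}.
\end{aligned}
```

Read this way, the returned `[1.0 0.0; 1.0 0.0]` is exactly the contribution of the indexed terms alone. The all-ones contribution from `sum(bb, dims=2)` is missing, which is why it looks like an overwrite rather than an accumulation.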
If this kind of construction is known to be disallowed, then trying to do it should raise an error. If it's not known to be disallowed, then this is a bug.