Consistency in the type behavior of restructure #95
I would actually be in favour of behaviour 3.

Now, another tricky thing is what to do about structured array types. Here I think we just have to enumerate as many weird cases as we can think of and come to an agreement on how to handle them all consistently. One such example:

```julia
julia> d = Dense(Diagonal(rand(Float32, 3)), false)
Dense(3 => 3; bias=false)  # 9 parameters

julia> d.weight
3×3 Diagonal{Float32, Vector{Float32}}:
 0.24043    ⋅         ⋅
  ⋅        0.657887   ⋅
  ⋅         ⋅        0.52947

julia> p, re = destructure(d)
(  [1]  =  0.24043
  [5]  =  0.657887
  [9]  =  0.52947, Restructure(Dense, ..., 9))

julia> p
9-element SparseArrays.SparseVector{Float32, Int64} with 3 stored entries:
  [1]  =  0.24043
  [5]  =  0.657887
  [9]  =  0.52947

julia> re(p)
Dense(3 => 3; bias=false)  # 9 parameters

julia> re(p) |> dump
Dense{typeof(identity), Diagonal{Float32, SparseArrays.SparseVector{Float32, Int64}}, Bool}
  weight: Diagonal{Float32, SparseArrays.SparseVector{Float32, Int64}}
    diag: SparseArrays.SparseVector{Float32, Int64}
      n: Int64 3
      nzind: Array{Int64}((3,)) [1, 2, 3]
      nzval: Array{Float32}((3,)) Float32[0.24042994, 0.6578865, 0.52947]
  bias: Bool false
  σ: identity (function of type typeof(identity))
```

And another one:

```julia
julia> d = Dense(rand(Float32, 3, 2), @SArray ones(3))
Dense(2 => 3)  # 9 parameters

julia> p, re = destructure(d)
(Float32[0.9659148, -0.7210188, 0.20607175, 0.7583495, 0.35627228, -0.5444089, 0.0, 0.0, 0.0], Restructure(Dense, ..., 9))

julia> re(p)
Dense(2 => 3)  # 9 parameters

julia> re(p) |> dump
Dense{typeof(identity), Matrix{Float32}, SizedVector{3, Float32, Vector{Float32}}}
  weight: Array{Float32}((3, 2)) Float32[0.9659148 0.7583495; -0.7210188 0.35627228; 0.20607175 -0.5444089]
  bias: SizedVector{3, Float32, Vector{Float32}}
    data: Array{Float32}((3,)) Float32[0.0, 0.0, 0.0]
  σ: identity (function of type typeof(identity))

julia> cu_p = cu(p)
9-element CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}:
  0.9659148
 -0.7210188
  0.20607175
  0.7583495
  0.35627228
 -0.5444089
  0.0
  0.0
  0.0

julia> re(cu_p) |> dump
Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, SizedVector{3, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}
  weight: CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}
    storage: CUDA.ArrayStorage{CUDA.Mem.DeviceBuffer}
      buffer: CUDA.Mem.DeviceBuffer
        ctx: CuContext
          handle: Ptr{Nothing} @0x0000000002ab0400
          valid: Bool true
        ptr: CuPtr{Nothing} CuPtr{Nothing}(0x0000000701bc0800)
        bytesize: Int64 24
        async: Bool false
      refcount: Base.Threads.Atomic{Int64}
        value: Int64 1
    maxsize: Int64 24
    offset: Int64 0
    dims: Tuple{Int64, Int64}
      1: Int64 3
      2: Int64 2
  bias: SizedVector{3, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}
    data: CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}
      storage: CUDA.ArrayStorage{CUDA.Mem.DeviceBuffer}
        buffer: CUDA.Mem.DeviceBuffer
          ctx: CuContext
            handle: Ptr{Nothing} @0x0000000002ab0400
            valid: Bool true
          ptr: CuPtr{Nothing} CuPtr{Nothing}(0x0000000701bc0a00)
          bytesize: Int64 12
          async: Bool false
        refcount: Base.Threads.Atomic{Int64}
          value: Int64 1
      maxsize: Int64 12
      offset: Int64 0
      dims: Tuple{Int64}
        1: Int64 3
  σ: identity (function of type typeof(identity))
```
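In the same spirit of enumerating weird cases, one more structured-weight example that might belong on the list (a sketch, not from the original comments; the reconstructed types are left as questions rather than asserted, since `Adjoint` handling comes up later in the thread):

```julia
using Flux, Optimisers, LinearAlgebra

d = Dense(adjoint(rand(Float32, 2, 3)), false)   # weight is an Adjoint wrapper
p, re = destructure(d)
typeof(p)                  # how does the Adjoint flatten?
typeof(re(p).weight)       # and does the wrapper survive reconstruction?
```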
I repeat that no incorrect gradients have been displayed here. Calling other features you happen to dislike in some context "gradient bugs" is just muddying the waters. (There are known gradient bugs; they are marked "bug" in the issues here.) Maybe it's helpful to understand what the goals of the present design are:
For 3., you may recall that #66 was precisely to address your complaint that [...]. Since ReverseDiff.jl also likes flat arrays, not nested trees, the same should go for its tracked arrays. If they don't propagate, I think that's a bug. But no need to guess: Tracker's arrays seem to work fine; something seems to make [...]

At present, this package does not know about GPU arrays, and thus makes no distinctions. If you think it's confusing that [...]
Re structured arrays, I suspect most of them should be marked [...]

This discards [...]
I don't claim to know what the right answer is, so I posted those examples because it's not clear if they'd be considered consistent enough to pass muster. Another one is [...]

On a meta level, I feel even more strongly now that the behaviour of [...]
Ok. `Adjoint` should now reconstruct:
I agree that things are a bit under-specified. Like everything else in Julia really -- it's a bit of an exploration to see what properties turn out to be useful, and how to compose them.
I don't disagree. There are no incorrect gradients here by the definition now in the docs. It's just an issue that only presents itself to downstream users via incorrect gradients (as demonstrated) in functions which are expected to behave the way a generic Julia function normally does. It's a very subtle distinction. I agree it's not incorrect as documented, but it is also very hard to spot that it's happening in most cases (with demonstrations as to why).
For v0.4 we should think about whether there is something we want to change in the behavior of [...]
This was discovered in SciML/NeuralPDE.jl#533 as an issue that only showed itself as an incorrect gradient: the primal passes of what was being trained were in Float64, the reverse passes gave a Float64, the loss function printout gave a Float64, and everything looked fine, except magically the Flux neural network was just "a bit more janky", in that it had a much higher probability of failing CI tests for a reason nobody could figure out for 5 months. Finally it was discovered that parts of the gradient were calculated in Float32 because the Flux.Chain had Float32 parameters in there. This showcased that `re(p)` does not "always" respect the types of `p`.

But it doesn't "always" respect the types of the Flux.Chain either. For example, for a standard Flux.Chain of Dense layers with Float32 parameters, you get:

- `re(p::Vector{Float64})` computes in Float32 (see the sketch after this list)
- `re(p::CuVector{Float32})` computes on the GPU in Float32
- `re(p::Vector{Dual})` computes with Dual numbers
- `re(p::Vector{ComplexF32})` computes with Float32

And now let's have some fun:

- `re(p::CuVector{Float64})` computes ???. My guess is CuVector{Float32}?
- `re(p::ReverseDiff.TrackedArray)` computes ??? My guess is Array{TrackedReal{Float32}}?
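To make the first bullet concrete, here is a minimal sketch (not from the original comments) of what the demotion looks like; it assumes Flux ≥ 0.13, where `destructure` is the Optimisers.jl one:

```julia
using Flux  # Flux re-exports destructure from Optimisers.jl

m = Chain(Dense(2 => 3, tanh), Dense(3 => 1))   # Float32 parameters by default
p, re = Flux.destructure(m)

p64 = Float64.(p)           # hand back a Float64 flat vector
eltype(re(p64)[1].weight)   # Float32: the reconstructed weights are demoted
```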
I understand that this isn't intended behavior and comes out of some quirks about `ProjectTo`, which exposes some (IMO odd) behaviors of a ChainRules internal to users who are likely not experts in the autodiff system.

Now, the problem I have with it is that discovering this behavior is rather hard, because if you do anything other than the simplest "just use the neural network", almost any case will not expose to the user that this behavior exists. For example, `(p[end] .* re(p))::typeof(p)` and `(p[end] .+ re(p))::typeof(p)` hold in the examples I described, because the type demotion is countered by the type promotion that's applied by essentially any other computation that uses things with the `eltype(p)`. Thus unless `re(p)` is the only operation that is used (in which case, you probably don't need to be using restructure/destructure), some other operation in the primal will mask the demotion and your forward pass will look like it computed using `typeof(p)`. It will only present itself to a user in the gradient pass.

Thus I understand @mcabbott's reasoning behind saying it's not a gradient correctness issue (since it's correctly calculating the gradients of the object that is actually reconstructed), but I have now traced many different cases that I thought were just "Flux janky behavior" and "I don't know why FastChain works here but Flux.Chain doesn't" all back to this same behavior. It may not be a gradient correctness issue, but it only presents itself as one in the downstream libraries where I have found it, it only really exposes itself if you try to look into a seemingly incorrect gradient, and if it quacks like a 🦆?
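Continuing the sketch above (again not from the original comments; exact types may depend on Flux/Optimisers versions), this is the masking effect: the primal types all look like Float64, and per the report above the returned gradient does too, even though parts of it passed through Float32:

```julia
using Flux

m = Chain(Dense(2 => 3, tanh), Dense(3 => 1))
p, re = Flux.destructure(m)
p64 = Float64.(p)

x = rand(2)                      # Float64 input
y = re(p64)(x) .* p64[end]       # promotion against Float64 hides the Float32 weights
eltype(y)                        # Float64, so the forward pass "looks" fine

g = Flux.gradient(q -> sum(abs2, re(q)(x)), p64)[1]
eltype(g)                        # Float64 as well, per the report above, even though
                                 # parts of the backward pass ran in Float32
```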
I understand that this behavior is now documented, but I'm not sure a behavior that presents itself like that is sufficiently handled just by documentation because it's hard to even figure out that something is going wrong without investigating the gradient calculation.
What could be done?
I would propose that we should just make the behavior undeniably straightforward and consistent. Either always make `re(p)` compute using values of `typeof(p)`, or make it always compute using the values from the original Flux.Chain. Either choice is an easily explainable and predictable behavior. This middle ground is not easy to explain or predict.

Always matching `p` is the more predictable behavior in the Julia ecosystem. If you stick a complex number as the initial condition in the ODE solver, as the initial guess for a value in Optim, or as the starting point for IterativeSolvers or NLsolve, etc., any generic code that I can think of will treat the computation in the sense that `p` provides. In many cases generic codes will just error if they can't handle it, but they try to compute using `p`. Non-generic codes immediately throw method errors describing what the allowed inputs are. I cannot think of another example in the Julia ecosystem where the "computation type" for `f(p)` does not match `p` or a fixed type, but instead matches the internal types of the fields of `f`, only sometimes, and other times matches `p`.

If it always matches the Flux.Chain, at least that would be clearly visible, since when you do it on a CuArray you see you get an Array back and you're like: oh, I see how this works. If I want the GPU, then I `|> gpu` the chain because it doesn't convert to `p`. Got it. With the current behavior, you see that `re(p)` works on the GPU, so okay, why not just do `re(p::Array{Float64})` as a quick way to convert to Float64? And if you think like that, you get burned.

The other behavior could be to throw an error in any case where a type conversion is necessary. If you want `re(p::Array{Float64})` to work, go back and `|> f64` the neural network. Now, this will cause some issues with making libraries work, but it's a nice (overly) safe option that would ensure there are no surprises.
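As a purely hypothetical illustration of that "throw an error" option (nothing here is an existing Optimisers.jl API; `strict_destructure` is an invented name), a wrapper could simply compare eltypes before rebuilding:

```julia
using Flux, Optimisers

# Hypothetical helper, not part of Optimisers.jl: refuse to rebuild when the
# flat vector's eltype differs from the model's.
function strict_destructure(model)
    p0, re = Optimisers.destructure(model)
    T = eltype(p0)
    strict_re(p) = eltype(p) === T ? re(p) :
        throw(ArgumentError("eltype(p) = $(eltype(p)) does not match the model's $T; convert the model (e.g. with f64) instead"))
    return p0, strict_re
end

p, re = strict_destructure(Dense(2 => 3))
re(p)              # fine: eltypes match
# re(Float64.(p))  # would throw instead of silently demoting to Float32
```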
re(p,Optimisers.NoConvert())
andre(p,Optimisers.Convert())
.Those 4 behaviors would be clear and easily predictable. I think the only option I would be adamantly against is the current behavior.