Some fast paths + type fixes #2137
Conversation
src/layers/normalise.jl
Outdated
γ = reshape(l.γ, affine_shape)
β = reshape(l.β, affine_shape)
return l.λ.(γ .* o .+ β)
s = (inv∘sqrt).(σ² .+ l.ϵ) # faster to un-fuse this, smaller... ideally mean_var(x, ε)?
Since it's unfused by Zygote anyhow, might as well do that here.
For just the forward pass, it was still faster to un-fuse this, to do inv & sqrt N times not N^3.
Isn't that what your comment is saying? I might be misunderstanding: does "un-fuse" here refer to extracting s as its own variable, or to writing s = inv.(sqrt.(σ² .+ l.ϵ)) instead of s = (inv∘sqrt).(σ² .+ l.ϵ)?
Yes, maybe we are agreeing. The comment was meant to answer "why make s at all", since without it things got slower. (inv∘sqrt) is probably premature optimisation.
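As a rough sketch of the trade-off being discussed (the sizes and the names x, μ, σ², ε below are illustrative, not the layer's actual code): in a fully fused broadcast, inv and sqrt run once per element of the big output, whereas materialising the small per-channel s first means they run only once per channel.

using Statistics

x  = randn(Float32, 28, 28, 64, 32)            # H×W×C×N activations
μ  = mean(x; dims=(1,2,4))                     # per-channel statistics
σ² = var(x; dims=(1,2,4), corrected=false)
ε  = 1f-5

# Fused: inv(sqrt(σ²+ε)) is recomputed for every element of the full output.
y_fused = (x .- μ) .* inv.(sqrt.(σ² .+ ε))

# Un-fused: inv & sqrt run only once per channel, then the small s is broadcast.
s = inv.(sqrt.(σ² .+ ε))
y_unfused = (x .- μ) .* s

Either way the result is identical; the only question is how many times the scalar inv/sqrt work is done.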
src/layers/normalise.jl
Outdated
return l.λ.(γ .* o .+ β)
s = (inv∘sqrt).(σ² .+ l.ϵ) # faster to un-fuse this, smaller... ideally mean_var(x, ε)?
if hasaffine(l)
γ = reshape(l.γ, affine_shape) # ideally reshape on construction, store Scale?
The issue with packing the affine params/activation in a Scale is that batchnorm functions in 3rd-party backends (notably cuDNN) expect them to be passed in alongside all the other params. Thus the NNlib-level API has to be batchnorm(x, ..., γ, β), so the Scale only exists as a container to hold the affine params.
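For illustration only (this sketch uses Flux's existing Scale layer, not the PR's code): even if the affine parameters lived in a Scale, a backend entry point that wants raw arrays would force you to unpack the container again right before the call.

using Flux

affine = Flux.Scale(ones(Float32, 64), zeros(Float32, 64), relu)   # container for γ, β, λ
γ, β, λ = affine.scale, affine.bias, affine.σ
# A backend call shaped like batchnorm(x, ..., γ, β) still sees plain arrays,
# so the Scale would only ever be a holder that gets unwrapped here.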
I see. We should probably still make these arrays the size required on construction, and make them even if they won't be used, instead of this:
https://github.com/FluxML/NNlibCUDA.jl/blob/master/src/cudnn/batchnorm.jl#L21
Does Flux ever call that NNlib code?
It's the only remaining CUDA.jl-reliant functionality in this repo aside from the Functors stuff: https://github.com/FluxML/Flux.jl/blob/master/src/cuda/cudnn.jl. It's an absolute kludge, as you can see, which is why these routines should be moved to NNlib sooner rather than later.
Oh right, I forgot about that file. But I remember seeing it when trying to remove CUDA... agree that NNlib is the right place.
Did a quick blame of the NNlibCUDA line above and came up with FluxML/NNlibCUDA.jl#36. I don't recall why the arrays are allocated instead of just set as CU_NULL before the call. The cuDNN docs don't mention that the bias and scale params can be null, so maybe that's why. If it turns out they can be and it's just not documented, though, we should revisit this.
reduce_dims = [1:N-2; N]
function (BN::BatchNorm)(x::AbstractArray{T,N}) where {T,N}
size(x, N-1) == BN.chs || error("BatchNorm expected an input with $(BN.chs) channels, got size(x) == $(size(x))")
reduce_dims = ntuple(d -> d + (d==N-1), N-1) # i.e. 1:N with N-1 removed
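For reference, both definitions pick out the same dimensions; the change is that the new one builds a Tuple instead of a Vector (shown here for N = 4):

N = 4
[1:N-2; N]                          # Vector{Int}: [1, 2, 4]
ntuple(d -> d + (d==N-1), N-1)      # Tuple: (1, 2, 4), i.e. 1:N with N-1 removed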
This change hits the following failure:
julia> Zygote.hessian_reverse(x -> sum(sin, mean(x.^2; dims=1)), [1 2; 3 4.0])
4×4 Matrix{Float64}:
1.24259 2.87677 0.0 0.0
2.87677 8.91398 0.0 0.0
0.0 0.0 1.33701 4.35217
0.0 0.0 4.35217 7.86527
julia> Zygote.hessian_reverse(x -> sum(sin, mean(x.^2; dims=[1])), [1 2; 3 4.0])
4×4 Matrix{Float64}:
1.24259 2.87677 0.0 0.0
2.87677 8.91398 0.0 0.0
0.0 0.0 1.33701 4.35217
0.0 0.0 4.35217 7.86527
julia> Zygote.hessian_reverse(x -> sum(sin, mean(x.^2; dims=(1,))), [1 2; 3 4.0])
ERROR: Mutating arrays is not supported -- called push!(Vector{Int64}, ...)
Stacktrace:
[3] (::Zygote.var"#397#398"{Vector{Int64}})(#unused#::Nothing)
@ Zygote ~/.julia/packages/Zygote/SmJK6/src/lib/array.jl:105
[4] (::Zygote.var"#2529#back#399"{Zygote.var"#397#398"{Vector{Int64}}})(Δ::Nothing)
@ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
[5] unique
@ ./set.jl:176 [inlined]
[6] (::typeof(∂(unique)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
[7] _denom
@ ~/.julia/packages/ChainRules/hVHC4/src/rulesets/Statistics/statistics.jl:7 [inlined]
[8] (::typeof(∂(_denom)))(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
[9] #rrule#1801
@ ~/.julia/packages/ChainRules/hVHC4/src/rulesets/Statistics/statistics.jl:13 [inlined]
[10] (::typeof(∂(#rrule#1801)))(Δ::Tuple{Matrix{Float64}, NamedTuple{(:n, :sum_pullback), Tuple{Float64, Nothing}}})
So it's differentiating this, and differentiating the rule for unique, which doesn't handle this case. Zygote differentiates so many things it need not touch; surely this adds startup time... you only notice when it fails.
One of its fatal flaws, you might say. Usually first-order differentiation is well-behaved because control flow and possible mutation are hidden away, but all bets are off with second order...
Even at first order, I think it does a lot it need not do; it's just that most of the resulting errors have already been found. The same thing shows up when trying out Diffractor -- lots of errors from obscure code calculating indices for views or whatever, which to a human is obviously non-diff.
This one is fixed in JuliaDiff/ChainRules.jl#687.
That's one definite benefit of tracing/overload-based ADs. Anything not numerically interesting gets ignored or falls away in the final tape/graph.
Yup. I presume that any kind of activity tracking would also let you eliminate most off-track things. Maybe declaring integers (and all structs not containing floats) non-diff would also help.
There's certainly a lot we could learn from projects like differentiable Swift (which uses activity analysis). It seems unlikely Zygote will be where such knowledge is applied given how poorly integrated it is with the compiler.
8e9a5cc also adds some explicit conversions on mismatched types. It will complain about Float64 input and then convert it down (such input is almost surely a mistake, hence the warning), but it silently converts integer input (unlikely to arise by mistake, but generic matmul on it is slow):

julia> model = Chain(Dense(784 => 1000, relu), Dense(1000, 10));
julia> x = rand(0:1, 784, 1000);
julia> @btime $model($x);
min 482.107 ms, mean 483.866 ms (11 allocations, 7.73 MiB) # before
min 2.589 ms, mean 3.697 ms (10 allocations, 10.70 MiB) # after
julia> x64 = rand(784, 1000);
julia> @btime $model($x64) |> typeof
min 9.153 ms, mean 90.407 ms (12 allocations, 21.47 MiB)
Matrix{Float64} (alias for Array{Float64, 2}) # before
┌ Warning: Layer with Float32 parameters got Float64 input.
│ The input will be converted, but any earlier layers may be very slow
│ layer = Dense(784 => 1000, relu) # 785_000 parameters
│ summary(x) = "784×1000 Matrix{Float64}"
└ @ Flux ~/.julia/dev/Flux/src/layers/basic.jl:191
min 2.596 ms, mean 3.857 ms (13 allocations, 10.70 MiB)
Matrix{Float32} (alias for Array{Float32, 2}) # after

This does not affect deliberate use with Float64 parameters.

Edit: fcbc7b4 upgrades this to a broader implementation. By default it now assumes that the type of the weights is the type of the output you wanted. That's contrary to Base Julia, but it actually seems like a good policy for Flux. If you want Float64, you should use Float64 parameters deliberately.

It seems a bit odd to make so much of this as a performance trap in the docs, when in fact we can just automatically detect it and fix most of the slowdown.
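A minimal sketch of the conversion behaviour described above, assuming a hypothetical helper (the name match_weight_eltype and the exact rules are illustrative; this is not Flux's actual implementation):

# Hypothetical helper: bring layer input to the weight eltype T.
# Float64 -> Float32 warns (probably a user mistake); integer input is
# converted silently (unlikely to be a mistake, but generic matmul is slow).
function match_weight_eltype(layer, ::Type{T}, x::AbstractArray) where {T<:AbstractFloat}
    if eltype(x) <: AbstractFloat && eltype(x) !== T
        @warn "Layer with $T parameters got $(eltype(x)) input; converting." layer maxlog=1
        return T.(x)
    elseif eltype(x) <: Integer
        return T.(x)
    end
    return x
end

# e.g. inside a Dense-like forward pass: x = match_weight_eltype(d, eltype(d.weight), x)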
size(x, N-1) == BN.chs || error("BatchNorm expected an input with $(BN.chs) channels, got size(x) == $(size(x))")
reduce_dims = ntuple(d -> d + (d==N-1), N-1) # i.e. 1:N with N-1 removed
Might as well take the opportunity to mark these lines and the definition of affine_shape below as ignored. BN causes a decent amount of Zygote compilation latency, so hiding anything that doesn't need to go through AD seems reasonable.
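A sketch of what that could look like, using ChainRulesCore's @ignore_derivatives (the affine_shape expression here is a guess for illustration, not the actual patch):

using ChainRulesCore: @ignore_derivatives

N = 4          # ndims of the input, standing in for the layer's type parameter
C = 64         # number of channels

# Index bookkeeping that the gradient never needs to see:
reduce_dims  = @ignore_derivatives ntuple(d -> d + (d==N-1), N-1)      # (1, 2, 4)
affine_shape = @ignore_derivatives ntuple(d -> d==N-1 ? C : 1, N)      # (1, 1, 64, 1)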
y = m(x)
@test isapprox(y, sigmoid.((x .- m.μ) ./ sqrt.(m.σ² .+ m.ϵ)), atol = 1.0e-7)
@inferred m(x)
@inferred m(x) # fails when x::Matrix{Float64}, do we care?
Do you know why this fails?
I do not. Checking the branches of if !_isactive(l) && l.track_stats, I get the same types on all paths.
I don't know, but we can get some idea by running the downstream tests. Label now added and will trigger with the next commit.
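One way to poke at this locally (illustrative; m and x64 stand in for the BatchNorm layer and the Float64 input from the test above):

using Flux, Test
m = BatchNorm(2, sigmoid)
Flux.testmode!(m)                # use the branch that reads the running stats
x64 = rand(Float64, 2, 5)
@inferred m(x64)                 # the reported inference failure for Matrix{Float64}
# @code_warntype m(x64)          # shows which branch introduces the instability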
Maybe this should be split into two.
This adds fast paths for some layers with neither bias nor a nonlinearity: Zygote knows to ignore identity.(x .+ false), so the only real gradient change is to BatchNorm.
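A minimal sketch of the kind of fast path meant here (a generic dense-style forward pass, not the PR's exact diff):

# Skip the σ.(W*x .+ b) broadcast entirely when there is no bias and no
# nonlinearity; otherwise fall back to the usual fused broadcast.
function dense_like(W, b, σ, x)
    y = W * x
    if σ === identity && b === false
        return y                  # fast path: identity.(y .+ false) would be a no-op anyway
    end
    return σ.(y .+ b)
end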