Some fast paths + type fixes #2137

Closed · wants to merge 5 commits
Conversation

mcabbott (Member):

This adds fast paths for some layers with neither bias nor a nonlinearity:

julia> using Flux

julia> @btime $(Dense(100 => 100; bias=false))($(randn(Float32, 100, 100)));
  min 6.917 μs, mean 243.044 μs (4 allocations, 78.22 KiB)  # before
  min 4.931 μs, mean 103.005 μs (2 allocations, 39.11 KiB)  # after

julia> @btime $(BatchNorm(3))($(randn(Float32, 100, 3, 100)));
  min 38.125 μs, mean 752.368 μs (19 allocations, 235.19 KiB)  # before
  min 28.167 μs, mean 410.862 μs (18 allocations, 118.02 KiB)  # after

julia> @btime $(Conv((3,3),3=>3,bias=false))($(randn(Float32, 28, 28, 3, 100)));
  min 313.708 μs, mean 5.634 ms (57 allocations, 1.83 MiB)  # before
  min 272.875 μs, mean 2.885 ms (54 allocations, 1.06 MiB)  # after

Zygote knows to ignore identity.(x .+ false), so the only real gradient change is to BatchNorm.

julia> @btime gradient((m,x) -> sum(abs2, m(x)), $(Dense(100 => 100; bias=false)),$(randn(Float32, 100, 100)));
  min 18.625 μs, mean 562.410 μs (31 allocations, 157.92 KiB)  # before
  min 18.792 μs, mean 594.045 μs (31 allocations, 157.44 KiB)  # after, just noise

julia> @btime gradient((m,x) -> sum(abs2, m(x)), $(BatchNorm(3)),$(randn(Float32, 100, 3, 100)));
  min 166.833 μs, mean 5.930 ms (571 allocations, 1.51 MiB)
  min 160.583 μs, mean 4.740 ms (586 allocations, 1.17 MiB)  # after, real improvement?

julia> @btime gradient((m,x) -> sum(abs2, m(x)), $(Conv((3,3),3=>3,bias=false)),$(randn(Float32, 28, 28, 3, 100)));
  min 2.503 ms, mean 12.723 ms (163 allocations, 3.09 MiB)
  min 2.472 ms, mean 13.070 ms (163 allocations, 3.09 MiB)  # after, just noise
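
(For context, a minimal sketch of the kind of fast path this refers to, written as a hypothetical standalone helper rather than the PR's actual implementation:)

# Hypothetical sketch: skip the bias-add and activation broadcast when both are no-ops.
function dense_forward(weight::AbstractMatrix, bias, σ, x::AbstractVecOrMat)
    if σ === identity && bias === false
        return weight * x                 # one matmul, no extra broadcast allocation
    else
        return σ.(weight * x .+ bias)     # general path
    end
end

Dispatching on that case removes one broadcast and one temporary array per call, which is consistent with the halved allocations in the Dense timing above.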

γ = reshape(l.γ, affine_shape)
β = reshape(l.β, affine_shape)
return l.λ.(γ .* o .+ β)
s = (inv∘sqrt).(σ² .+ l.ϵ) # faster to un-fuse this, smaller... ideally mean_var(x, ε)?
Member:

Since it's unfused by Zygote anyhow, might as well do that here.

mcabbott (Member Author):

For just the forward pass, it was still faster to un-fuse this, so that inv & sqrt run N times rather than N^3 times.

Member:

Isn't that what your comment is saying? I might be misunderstanding: does "un-fuse" here refer to extracting s as its own variable, or to writing s = inv.(sqrt.(σ² .+ l.ϵ)) instead of s = (inv∘sqrt).(σ² .+ l.ϵ)?

mcabbott (Member Author):

Yes, maybe we are agreeing. The comment was meant to answer "why make s at all?", since without it things got slower. The (inv∘sqrt) is probably premature optimisation.
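
(For concreteness, a minimal sketch of the two broadcast shapes being discussed, with hypothetical variable names rather than the PR code:)

x  = randn(Float32, 100, 3, 100)      # W × C × N input
μ  = randn(Float32, 1, 3, 1)          # per-channel mean
σ² = abs.(randn(Float32, 1, 3, 1))    # per-channel variance
ϵ  = 1f-5

# Fully fused: nested dot calls fuse into one kernel, so inv∘sqrt runs once
# per element of x (30_000 times here).
y_fused = (x .- μ) .* (inv∘sqrt).(σ² .+ ϵ)

# Un-fused: materialise the tiny per-channel factor s first, so inv and sqrt
# run only 3 times, followed by one cheap multiply-broadcast.
s = (inv∘sqrt).(σ² .+ ϵ)
y_unfused = (x .- μ) .* s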

return l.λ.(γ .* o .+ β)
s = (inv∘sqrt).(σ² .+ l.ϵ) # faster to un-fuse this, smaller... ideally mean_var(x, ε)?
if hasaffine(l)
γ = reshape(l.γ, affine_shape) # ideally reshape on construction, store Scale?
Member:

The issue with packing the affine params/activation in a Scale is that batchnorm functions in 3rd party backends (notably cuDNN) expect them to be passed in alongside all the other params. Thus the NNlib-level API has to be batchnorm(x, ..., γ, β), so the Scale only exists as a container to hold the affine params.
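
(A hypothetical illustration of that constraint, not NNlib's or cuDNN's actual signature: the backend-level function takes the affine parameters as plain arrays, so a Scale layer would only ever be unpacked again at the call site.)

# Hypothetical backend-style batchnorm taking γ and β as plain vectors:
function backend_batchnorm(γ::AbstractVector, β::AbstractVector, x::AbstractArray,
                           μ::AbstractVector, σ²::AbstractVector; ϵ = 1f-5)
    shape = ntuple(d -> d == ndims(x) - 1 ? length(γ) : 1, ndims(x))  # broadcast along channels
    s = (inv∘sqrt).(reshape(σ², shape) .+ ϵ)
    return reshape(γ, shape) .* (x .- reshape(μ, shape)) .* s .+ reshape(β, shape)
end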

mcabbott (Member Author):

I see. We should probably still make these arrays the size required on construction, and make them even if they won't be used, instead of this:

https://github.com/FluxML/NNlibCUDA.jl/blob/master/src/cudnn/batchnorm.jl#L21

mcabbott (Member Author):

Does Flux ever call that NNlib code?

ToucheSir (Member) commented Dec 21, 2022:

It's the only remaining CUDA.jl-reliant functionality left in this repo aside from the Functors stuff: https://github.com/FluxML/Flux.jl/blob/master/src/cuda/cudnn.jl. An absolute kludge, as you can see, which is why these routines should be moved to NNlib sooner rather than later.

mcabbott (Member Author):

Oh right, I forgot about that file, but I remember seeing it when trying to remove CUDA... I agree that NNlib is the right place.

Member:

Did a quick blame of the NNlibCUDA line above and came up with FluxML/NNlibCUDA.jl#36. I don't recall why the arrays are allocated instead of just being set to CU_NULL before the call. The cuDNN docs don't mention that the bias and scale params can be null, so maybe that's why. If it turns out they can be and it's just not documented, though, we should revisit this.

reduce_dims = [1:N-2; N]
function (BN::BatchNorm)(x::AbstractArray{T,N}) where {T,N}
size(x, N-1) == BN.chs || error("BatchNorm expected an input with $(BN.chs) channels, got size(x) == $(size(x))")
reduce_dims = ntuple(d -> d + (d==N-1), N-1) # i.e. 1:N with N-1 removed
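
(For concreteness, what the new ntuple expression evaluates to; an illustrative REPL check, not part of the diff:)

julia> N = 4; ntuple(d -> d + (d==N-1), N-1)   # e.g. conv-style input, channels in dim 3
(1, 2, 4)

julia> N = 3; ntuple(d -> d + (d==N-1), N-1)   # channels in dim 2
(1, 3)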
mcabbott (Member Author):

This change hits the following failure:

julia> Zygote.hessian_reverse(x -> sum(sin, mean(x.^2; dims=1)), [1 2; 3 4.0])
4×4 Matrix{Float64}:
 1.24259  2.87677  0.0      0.0
 2.87677  8.91398  0.0      0.0
 0.0      0.0      1.33701  4.35217
 0.0      0.0      4.35217  7.86527

julia> Zygote.hessian_reverse(x -> sum(sin, mean(x.^2; dims=[1])), [1 2; 3 4.0])
4×4 Matrix{Float64}:
 1.24259  2.87677  0.0      0.0
 2.87677  8.91398  0.0      0.0
 0.0      0.0      1.33701  4.35217
 0.0      0.0      4.35217  7.86527

julia> Zygote.hessian_reverse(x -> sum(sin, mean(x.^2; dims=(1,))), [1 2; 3 4.0])
ERROR: Mutating arrays is not supported -- called push!(Vector{Int64}, ...)
Stacktrace:
  [3] (::Zygote.var"#397#398"{Vector{Int64}})(#unused#::Nothing)
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/lib/array.jl:105
  [4] (::Zygote.var"#2529#back#399"{Zygote.var"#397#398"{Vector{Int64}}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
  [5] unique
    @ ./set.jl:176 [inlined]
  [6] (::typeof((unique)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
  [7] _denom
    @ ~/.julia/packages/ChainRules/hVHC4/src/rulesets/Statistics/statistics.jl:7 [inlined]
  [8] (::typeof((_denom)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
  [9] #rrule#1801
    @ ~/.julia/packages/ChainRules/hVHC4/src/rulesets/Statistics/statistics.jl:13 [inlined]
 [10] (::typeof((#rrule#1801)))(Δ::Tuple{Matrix{Float64}, NamedTuple{(:n, :sum_pullback), Tuple{Float64, Nothing}}})

So it's differentiating this:

https://github.com/JuliaDiff/ChainRules.jl/blob/9a405f732758552cd945a110adb6828a997887a8/src/rulesets/Statistics/statistics.jl#L7

and, inside that, differentiating the rule for unique, which doesn't handle this case.

Zygote differentiates so many things it need not touch; surely this adds startup time, yet you only notice when it fails.

Member:

One of its fatal flaws, you might say. Usually first-order differentiation is well-behaved because control flow and possible mutation are hidden away, but all bets are off with second order...

mcabbott (Member Author):

Even at first order, I think it does a lot that it need not do; it's just that most of the resulting errors have already been found. The same thing shows up when trying out Diffractor -- lots of errors from obscure code computing indices for views and the like, which to a human is obviously non-differentiable.

mcabbott (Member Author):

This one is fixed in JuliaDiff/ChainRules.jl#687.
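
(For reference, the general ChainRulesCore mechanism for keeping such purely structural helpers out of the AD entirely; a sketch of the pattern with a hypothetical helper, not necessarily what that PR does:)

using ChainRulesCore

# Hypothetical helper that only counts elements over some dims:
_count_over_dims(x, dims) = prod(size(x, d) for d in unique(dims))

# Declare it non-differentiable so no AD ever tries to trace unique/push!:
ChainRulesCore.@non_differentiable _count_over_dims(::Any, ::Any)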

Member:

That's one definite benefit of tracing/overload-based ADs. Anything not numerically interesting gets ignored or falls away in the final tape/graph.

mcabbott (Member Author):

Yup. I presume that any kind of activity tracking would also let you eliminate most off-track things. Maybe declaring integers (and all structs not containing floats) non-diff would also help.

Member:

There's certainly a lot we could learn from projects like differentiable Swift (which uses activity analysis). It seems unlikely Zygote will be where such knowledge is applied given how poorly integrated it is with the compiler.

mcabbott (Member Author) commented Dec 23, 2022:

8e9a5cc also adds some explicit conversions for mismatched input types. It will complain about Float64 input and then convert it down (since that is almost surely a mistake), but silently convert integers (which are unlikely to arise by mistake, yet would otherwise hit slow generic matmul):

julia> model = Chain(Dense(784 => 1000, relu), Dense(1000, 10));

julia> x = rand(0:1, 784, 1000);

julia> @btime $model($x);
  min 482.107 ms, mean 483.866 ms (11 allocations, 7.73 MiB)  # before
  min 2.589 ms, mean 3.697 ms (10 allocations, 10.70 MiB)     # after

julia> x64 = rand(784, 1000);

julia> @btime $model($x64) |> typeof
  min 9.153 ms, mean 90.407 ms (12 allocations, 21.47 MiB)
Matrix{Float64} (alias for Array{Float64, 2})  # before

┌ Warning: Layer with Float32 parameters got Float64 input.
│   The input will be converted, but any earlier layers may be very slow
│   layer = Dense(784 => 1000, relu)  # 785_000 parameters
│   summary(x) = "784×1000 Matrix{Float64}"
└ @ Flux ~/.julia/dev/Flux/src/layers/basic.jl:191
  min 2.596 ms, mean 3.857 ms (13 allocations, 10.70 MiB)
Matrix{Float32} (alias for Array{Float32, 2})  # after

Does not affect deliberate use with Float64 parameters.

Edit: fcbc7b4 upgrades this to a broader implementation, inserting _match_eltype in many places. Thoughts?

By default, this now assumes that the eltype of the weights is the eltype of the output you wanted. That's contrary to Base Julia's promotion rules, but it actually seems like a good policy for Flux. If you want Float64, you should use |> f64. Will this policy break anything people wanted?

It seems a bit odd to make so much of this as a performance trap in the docs, when in fact we can just automatically detect it and fix most of the slowdown.
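
(For concreteness, a minimal sketch of the kind of eltype-matching policy described above; illustrative only, with a hypothetical helper name, not the PR's _match_eltype:)

# Hypothetical sketch: convert the input to the layer's parameter eltype.
function convert_to_param_eltype(::Type{T}, x::AbstractArray) where {T<:AbstractFloat}
    eltype(x) == T && return x                   # already matching: no copy
    if eltype(x) <: AbstractFloat                # e.g. Float64 into a Float32 layer: probably a mistake, so warn
        @warn "Layer with $T parameters got $(eltype(x)) input; converting." maxlog=1
    end                                          # integer input converts silently: deliberate, but generic matmul is slow
    return T.(x)
end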

mcabbott changed the title from "Some fast paths" to "Some fast paths + type fixes" on Dec 23, 2022
Comment on lines +360 to +361
size(x, N-1) == BN.chs || error("BatchNorm expected an input with $(BN.chs) channels, got size(x) == $(size(x))")
reduce_dims = ntuple(d -> d + (d==N-1), N-1) # i.e. 1:N with N-1 removed
Member:

Might as well take the opportunity to mark these lines and the definition of affine_shape below as ignored. BN causes a decent amount of Zygote compilation latency, so hiding anything that doesn't need to go through AD seems reasonable.
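
(A sketch of what that could look like, using ChainRulesCore's @ignore_derivatives; illustrative, not the actual diff:)

using ChainRulesCore: @ignore_derivatives

# Keep the pure shape bookkeeping out of the AD so Zygote never traces it:
function norm_shapes(x::AbstractArray{T,N}, chs::Int) where {T,N}
    reduce_dims  = @ignore_derivatives ntuple(d -> d + (d == N-1), N-1)
    affine_shape = @ignore_derivatives ntuple(d -> d == N-1 ? chs : 1, N)
    return reduce_dims, affine_shape
end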

y = m(x)
@test isapprox(y, sigmoid.((x .- m.μ) ./ sqrt.(m.σ² .+ m.ϵ)), atol = 1.0e-7)
@inferred m(x)
@inferred m(x) # fails when x::Matrix{Float64}, do we care?
Member:

Do you know why this fails?

mcabbott (Member Author):

I do not. Checking the branches of if !_isactive(l) && l.track_stats, I get the same types on all paths.
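
(One generic way to chase this down with standard tooling, not something from this thread, is to inspect the call for union-typed variables:)

julia> using Flux, InteractiveUtils

julia> m = BatchNorm(3); x = randn(Float64, 100, 3, 100);

julia> @code_warntype m(x)   # look for Union/Any-typed variables in the printed body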

ToucheSir (Member):

> Will this policy break anything people wanted?

I don't know, but we can get some idea by running the downstream tests. The label is now added and will trigger with the next commit.

mcabbott mentioned this pull request on Jan 5, 2023.
mcabbott marked this pull request as draft on Jan 5, 2023.
mcabbott (Member Author) commented Jan 5, 2023:

Maybe this should be split into two.

This was referenced on Jan 5, 2023.
mcabbott closed this on Jan 8, 2023.