Improved type stability with explicit params #1248
Conversation
Force-pushed from 6ccf008 to 7540bd6
Ok, now we're back to the known failures on Nightly and downstream NeuralPDE. The Molly one appears to be intermittent (a minor numerical error) and showed up recently in PRs from a few days ago, but that should be investigated separately. One unexpected find while working on this PR is that Zygote and downstream packages were calling
This doesn't sound crazy, but FWIW I do not see the same speedup:
(Julia master, M1 mac.)
That may be because
Avoiding globals makes a surprisingly large difference here:
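To spell out the globals point, here is a minimal sketch (the `run_bench` name and the toy model are illustrative, not from this thread): timing the calls from inside a function, or marking the top-level bindings `const`, gives the compiler concrete types instead of untyped globals.

```julia
using Flux  # Flux re-exports Zygote's `gradient`

loss(m, x) = sum(m(x))

function run_bench(model, x)
    # Inside a function, `model` and `x` are locals with concrete types,
    # so `@time` measures compilation + execution without global-variable overhead.
    @time loss(model, x)
    @time gradient(m -> loss(m, x), model)
    return nothing
end

run_bench(Dense(4 => 4), randn(Float32, 4, 8))

# Equivalent alternative: keep the script layout but declare
# `const model = ...` and `const lr_images = ...` at top level.
```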
Perhaps a newer Julia version helps close the gap, but I'm consistently seeing this ~15s difference on 1.7.3. This is the full script I'm using:

```julia
using Flux

channels = 4

function resblock(channels)
    return SkipConnection(Chain(
        Conv((3, 3), channels => channels, pad=1),
        Conv((3, 3), channels => channels, pad=1),
    ), +)
end

model = Chain(
    SkipConnection(
        Chain(
            resblock(channels),
            resblock(channels),
            resblock(channels),
            resblock(channels),
            resblock(channels),
            resblock(channels),
            resblock(channels),
            resblock(channels),
            resblock(channels),
            resblock(channels),
            resblock(channels),
            resblock(channels),
            resblock(channels),
            resblock(channels),
            resblock(channels),
        ),
        +),
    AdaptiveMeanPool((1, 1))
)
@show typeof(model)

loss(m, x) = sum(m(x))
lr_images = randn(Float32, 2, 2, channels, 1)
@time loss(model, lr_images)
@time loss(model, lr_images)

loss_grad(m, x) = gradient(m -> loss(m, x), m)
# This gives the same numbers:
# loss_grad(m, x) = gradient((m, x) -> loss(m, x), m, x)
@time loss_grad(model, lr_images)
@time loss_grad(model, lr_images)
```
We can disable accumulating (implicit) parameters to the gradient cache in explicit mode. This can dramatically improve type stability because `accum_param` will return a `Union{Nothing, [grad type]}` otherwise.
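To illustrate the type-stability point, here is a standalone sketch; the cache and function names below are hypothetical stand-ins, not Zygote's actual internals.

```julia
# Hypothetical stand-ins for a gradient cache and `accum_param`;
# not the library's real implementation.
const CACHE = IdDict{Any,Any}()

# Implicit mode: the callee may or may not find `x` in the cache, so the
# inferred return type is Union{Nothing, typeof(dx)} and everything
# downstream of the call inherits that instability.
accum_param_implicit(x, dx) = haskey(CACHE, x) ? (CACHE[x] = dx) : nothing

# Explicit-only mode: with no implicit params to track, the function can
# return the gradient unchanged, so the return type stays concrete.
accum_param_explicit(x, dx) = dx

x, dx = rand(3), rand(3)
# julia> @code_warntype accum_param_implicit(x, dx)   # Union{Nothing, Vector{Float64}}
# julia> @code_warntype accum_param_explicit(x, dx)   # Vector{Float64}
```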
Force-pushed from 8db910e to e9a6075
Co-authored-by: Michael Abbott <[email protected]>
Let's do it.
Should be marked as closing #1243?
We can disable accumulating (implicit) parameters to the gradient cache in explicit mode. This can dramatically improve type stability because `accum_param` will return a `Union{Nothing, [grad type]}` otherwise.

One impact of this PR is that taking gradients of functions with both implicit and explicit parameters (i.e. calling `pullback` twice) may involve some additional compilation. However, given that we're trying to move users off of using implicit params anyhow, I see it as a small price to pay for being friendlier to the compiler.

Benchmarking TTFG on the MWE in #1126, modified to use explicit params:
Closes #1243.
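For readers unfamiliar with the two calling styles being compared, here is a small illustration using a toy model (this is not the MWE from #1126 and not the benchmark itself):

```julia
using Flux  # re-exports Zygote's `gradient`

model = Dense(4 => 2)
x = randn(Float32, 4, 8)
loss(m, x) = sum(abs2, m(x))

# Implicit parameters: differentiate a zero-argument closure with respect to
# the arrays tracked by `Flux.params(model)`; results are keyed by those arrays.
gs_implicit = gradient(() -> loss(model, x), Flux.params(model))

# Explicit parameters: differentiate with respect to the model itself;
# the result is a structural gradient mirroring the model's fields.
gs_explicit, = gradient(m -> loss(m, x), model)
```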